//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/GPUDialect.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/StringSaver.h"

using namespace mlir;
using namespace mlir::gpu;

#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
//===----------------------------------------------------------------------===//
// GPU Device Mapping Attributes
//===----------------------------------------------------------------------===//

int64_t GPUBlockMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getBlock());
}

bool GPUBlockMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUBlockMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpgroupMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarpgroup());
}

bool GPUWarpgroupMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarp());
}

bool GPUWarpMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUThreadMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getThread());
}

bool GPUThreadMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUThreadMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getAddressSpace());
}

bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
}

int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
}

//===----------------------------------------------------------------------===//
// MMAMatrixType
//===----------------------------------------------------------------------===//

MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
                                 StringRef operand) {
  return Base::get(elementType.getContext(), shape, elementType, operand);
}

MMAMatrixType
MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                          ArrayRef<int64_t> shape, Type elementType,
                          StringRef operand) {
  return Base::getChecked(emitError, elementType.getContext(), shape,
                          elementType, operand);
}

unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }

ArrayRef<int64_t> MMAMatrixType::getShape() const {
  return getImpl()->getShape();
}

Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }

StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }

bool MMAMatrixType::isValidElementType(Type elementType) {
  return elementType.isF16() || elementType.isF32() ||
         elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
         elementType.isInteger(32);
}

LogicalResult
MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
                      ArrayRef<int64_t> shape, Type elementType,
                      StringRef operand) {
  if (!operand.equals("AOp") && !operand.equals("BOp") &&
      !operand.equals("COp"))
    return emitError() << "operand expected to be one of AOp, BOp or COp";

  if (shape.size() != 2)
    return emitError() << "MMAMatrixType must have exactly two dimensions";

  if (!MMAMatrixType::isValidElementType(elementType))
    return emitError()
           << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";

  return success();
}
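
// For illustration, the textual form of this type as round-tripped by
// GPUDialect::parseType/printType below (a sketch, not copied from the
// upstream tests):
//
//   !gpu.mma_matrix<16x16xf16, "AOp">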

//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//

/// GPU memory space identifiers.
enum GPUMemorySpace {
  /// Generic memory space identifier.
  kGenericMemorySpace = 0,

  /// Global memory space identifier.
  kGlobalMemorySpace = 1,

  /// Shared memory space identifier.
  kSharedMemorySpace = 3
};

bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

namespace {
/// This class defines the interface for handling inlining with gpu
/// operations.
struct GPUInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All gpu dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace

void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addTypes<SparseDnTensorHandleType>();
  addTypes<SparseSpMatHandleType>();
  addTypes<SparseSpGEMMOpHandleType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
}

static std::string getSparseHandleKeyword(SparseHandleKind kind) {
  switch (kind) {
  case SparseHandleKind::DnTensor:
    return "sparse.dntensor_handle";
  case SparseHandleKind::SpMat:
    return "sparse.spmat_handle";
  case SparseHandleKind::SpGEMMOp:
    return "sparse.spgemmop_handle";
  }
  llvm_unreachable("unknown sparse handle kind");
  return "";
}

Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(context);

  if (keyword == "mma_matrix") {
    SMLoc beginLoc = parser.getNameLoc();

    // Parse '<'.
    if (parser.parseLess())
      return nullptr;

    // Parse the size and elementType.
    SmallVector<int64_t> shape;
    Type elementType;
    if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
        parser.parseType(elementType))
      return nullptr;

    // Parse ','.
    if (parser.parseComma())
      return nullptr;

    // Parse operand.
    std::string operand;
    if (failed(parser.parseOptionalString(&operand)))
      return nullptr;

    // Parse '>'.
    if (parser.parseGreater())
      return nullptr;

    return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
                                         parser.getEncodedSourceLoc(beginLoc)),
                                     shape, elementType, operand);
  }

  if (keyword == getSparseHandleKeyword(SparseHandleKind::DnTensor))
    return SparseDnTensorHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpMat))
    return SparseSpMatHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpGEMMOp))
    return SparseSpGEMMOpHandleType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
  return Type();
}
// TODO: print refined type here. Note that it should correspond to the
// parser.
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
      .Case<SparseDnTensorHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::DnTensor);
      })
      .Case<SparseSpMatHandleType>(
          [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
      .Case<SparseSpGEMMOpHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
      })
      .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
        os << "mma_matrix<";
        auto shape = fragTy.getShape();
        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
          os << *dim << 'x';
        os << shape.back() << 'x' << fragTy.getElementType();
        os << ", \"" << fragTy.getOperand() << "\"" << '>';
      })
      .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
}

LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (!llvm::isa<UnitAttr>(attr.getValue()) ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deeply than functions in
    // the module we are currently checking.
    if (!launchOp->getParentOp() ||
        launchOp->getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp->getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName(launchOp->getName())))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel container.
    StringAttr kernelContainerName = launchOp.getKernelModuleName();
    Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
    if (!kernelContainer)
      return launchOp.emitOpError()
             << "kernel container '" << kernelContainerName.getValue()
             << "' is undefined";

    // If the container is a GPU binary op, return success.
    if (isa<BinaryOp>(kernelContainer))
      return success();

    auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelContainerName.getValue()
             << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
    if (!kernelFunc)
      return launchOp.emitOpError("kernel function '")
             << launchOp.getKernel() << "' is undefined";
    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
    if (!kernelConvertedFunction) {
      InFlightDiagnostic diag = launchOp.emitOpError()
                                << "referenced kernel '" << launchOp.getKernel()
                                << "' is not a function";
      diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
      return diag;
    }

    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // TODO: If the kernel isn't a GPU function (which happens during separate
    // compilation), do not check type correspondence as it would require the
    // verifier to be aware of the type conversion.
    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
    if (!kernelGPUFunction)
      return success();

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    auto functionType = kernelGPUFunction.getFunctionType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}

/// Parses an optional list of async operands with an optional leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
///
/// This method is used by the tablegen assembly format for async ops as well.
static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
}
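
// For illustration, the async clause this helper parses, shown on gpu.wait
// (hypothetical SSA names):
//
//   %token = gpu.wait async [%dep0, %dep1]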

/// Prints optional async dependencies with its leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
// Used by the tablegen assembly format for several async ops.
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async";
  if (asyncDependencies.empty())
    return;
  if (asyncTokenType)
    printer << ' ';
  printer << '[';
  llvm::interleaveComma(asyncDependencies, printer);
  printer << ']';
}

// GPU Memory attributions functions shared by LaunchOp and GPUFuncOp.
/// Parses a GPU function memory attribution.
///
/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                        (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  return parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                                  /*allowType=*/true);
}
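
// For illustration, an attribution list in the custom syntax (a sketch with
// hypothetical names; 3 and 5 stand for the numeric workgroup/private address
// spaces commonly seen after lowering):
//
//   workgroup(%sum : memref<32xf32, 3>) private(%tmp : memref<1xf32, 5>)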

/// Prints a GPU function memory attribution.
static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); });
  p << ')';
}

/// Verifies a GPU function memory attribution.
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        gpu::AddressSpace memorySpace) {
  for (Value v : attributions) {
    auto type = llvm::dyn_cast<MemRefType>(v.getType());
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    // We can only verify the address space if it hasn't already been lowered
    // from the AddressSpaceAttr to a target-specific numeric value.
    auto addressSpace =
        llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
    if (!addressSpace)
      continue;
    if (addressSpace.getValue() != memorySpace)
      return op->emitOpError()
             << "expected memory space " << stringifyAddressSpace(memorySpace)
             << " in attribution";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//

static bool verifyReduceOpAndType(gpu::AllReduceOperation opName,
                                  Type resType) {
  return (opName != gpu::AllReduceOperation::AND &&
          opName != gpu::AllReduceOperation::OR &&
          opName != gpu::AllReduceOperation::XOR) ||
         llvm::isa<IntegerType>(resType);
}

LogicalResult gpu::AllReduceOp::verifyRegions() {
  if (getBody().empty() != getOp().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    if (getBody().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    gpu::AllReduceOperation opName = *getOp();
    if (!verifyReduceOpAndType(opName, getType())) {
      return emitError()
             << '`' << gpu::stringifyAllReduceOperation(opName)
             << "` accumulator is only compatible with Integer type";
    }
  }
  return success();
}
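
// For illustration, the two forms checked above (sketches with hypothetical
// names): the op-attribute form
//
//   %sum = gpu.all_reduce add %operand {} : (f32) -> (f32)
//
// and the region form, whose block must end in gpu.yield:
//
//   %res = gpu.all_reduce %operand {
//   ^bb(%lhs : f32, %rhs : f32):
//     %acc = arith.addf %lhs, %rhs : f32
//     gpu.yield %acc : f32
//   } : (f32) -> (f32)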

static bool canMakeGroupOpUniform(Operation *op) {
  auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
  if (!launchOp)
    return false;

  Region &body = launchOp.getBody();
  assert(!body.empty() && "Invalid region");

  // Only convert ops in gpu::launch entry block for now.
  return op->getBlock() == &body.front();
}

OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

// TODO: Support optional custom attributes (without dialect prefix).
static ParseResult parseAllReduceOperation(AsmParser &parser,
                                           AllReduceOperationAttr &attr) {
  StringRef enumStr;
  if (!parser.parseOptionalKeyword(&enumStr)) {
    std::optional<AllReduceOperation> op =
        gpu::symbolizeAllReduceOperation(enumStr);
    if (!op)
      return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
    attr = AllReduceOperationAttr::get(parser.getContext(), *op);
  }
  return success();
}

static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
                                    AllReduceOperationAttr attr) {
  if (attr)
    attr.print(printer);
}

//===----------------------------------------------------------------------===//
// SubgroupReduceOp
//===----------------------------------------------------------------------===//

LogicalResult gpu::SubgroupReduceOp::verify() {
  gpu::AllReduceOperation opName = getOp();
  if (!verifyReduceOpAndType(opName, getType())) {
    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                       << "` accumulator is only compatible with Integer type";
  }
  return success();
}

OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AsyncOpInterface
//===----------------------------------------------------------------------===//

void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(0, {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);

  // Async dependencies is the only variadic operand.
  if (!sizeAttr)
    return;

  SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
  ++sizes.front();
  op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
}
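
// For illustration (hypothetical names): given an async op such as
//
//   %t2 = gpu.memcpy async [%t0, %t1] %dst, %src : memref<?xf32>, memref<?xf32>
//
// addAsyncDependency prepends a new token to the `[...]` list and, when the
// op uses AttrSizedOperandSegments, bumps the leading segment size to match.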

//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value getBlockSizeX, Value getBlockSizeY,
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,
                     Type asyncTokenType, ValueRange asyncDependencies,
                     TypeRange workgroupAttributions,
                     TypeRange privateAttributions) {
  // Add a WorkGroup attribution attribute. This attribute is required to
  // identify private attributions in the list of block arguments.
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));

  // Add Op operands.
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                      getBlockSizeY, getBlockSizeZ});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  // Create a kernel body region with kNumConfigRegionAttributes + N memory
  // attributions, where the first kNumConfigRegionAttributes arguments have
  // `index` type and the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = new Block();
  // TODO: Allow passing in proper locations here.
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(builder.getIndexType(), result.location);
  // Add WorkGroup & Private attributions to the region arguments.
  for (Type argTy : workgroupAttributions)
    body->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    body->addArgument(argTy, result.location);
  kernelRegion->push_back(body);
  // Fill OperandSegmentSize Attribute.
  SmallVector<int32_t, 8> segmentSizes(8, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}

KernelDim3 LaunchOp::getBlockIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

KernelDim3 LaunchOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!getBody().empty()) {
    if (getBody().getNumArguments() <
        kNumConfigRegionAttributes + getNumWorkgroupAttributions())
      return emitOpError("unexpected number of region arguments");
  }

  // Verify Attributions Address Spaces.
  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}

// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}

void LaunchOp::print(OpAsmPrinter &p) {
  if (getAsyncToken()) {
    p << " async";
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }
  // Print the launch configuration.
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());

  p << ' ';

  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              LaunchOp::getOperandSegmentSizeAttr(),
                              getNumWorkgroupAttributionsAttrName()});
}

// Parse the size assignment blocks for blocks and threads. These have the form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
  if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
                              /*allowResultNumber=*/false) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
        parser.parseEqual() || parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}

/// Parses a Launch operation.
/// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
///               `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
///               `threads` `(` ssa-id-list `)` `in` ssa-reassignment
///               memory-attribution
///               region attr-dict?
/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
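///
/// For illustration, a launch matching this grammar (a sketch; `%c1`/`%c4`
/// are assumed index constants):
///
///   gpu.launch blocks(%bx, %by, %bz) in (%gx = %c1, %gy = %c1, %gz = %c1)
///              threads(%tx, %ty, %tz) in (%sx = %c4, %sy = %c1, %sz = %c1) {
///     "some_op"(%bx, %tx) : (index, index) -> ()
///     gpu.terminator
///   }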
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
      sizes(LaunchOp::kNumConfigOperands);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);

  // Actual (data) operands passed to the kernel.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> dataOperands;

  // Region arguments to be created.
  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);

  // Parse optional async dependencies.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
  Type asyncTokenType;
  if (failed(
          parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
      parser.resolveOperands(asyncDependencies, asyncTokenType,
                             result.operands))
    return failure();
  if (parser.getNumResults() > 0)
    result.types.push_back(asyncTokenType);

  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers. In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
  bool hasDynamicSharedMemorySize = false;
  if (!parser.parseOptionalKeyword(
          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
    hasDynamicSharedMemorySize = true;
    if (parser.parseOperand(dynamicSharedMemorySize) ||
        parser.resolveOperand(dynamicSharedMemorySize,
                              parser.getBuilder().getI32Type(),
                              result.operands))
      return failure();
  }

  // Create the region arguments. The region has kNumConfigRegionAttributes
  // arguments that correspond to block/thread identifiers and grid/block
  // sizes, all having `index` type, plus a variadic number of WorkGroup
  // attributions and a variadic number of Private attributions. The number of
  // WorkGroup attributions is stored in the attr with name:
  // LaunchOp::getNumWorkgroupAttributionsAttrName().
  Type index = parser.getBuilder().getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes, index);

  SmallVector<OpAsmParser::Argument> regionArguments;
  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
    OpAsmParser::Argument arg;
    arg.ssaName = std::get<0>(ssaValueAndType);
    arg.type = std::get<1>(ssaValueAndType);
    regionArguments.push_back(arg);
  }

  Builder &builder = parser.getBuilder();
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
                               regionArguments)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs =
      regionArguments.size() - LaunchOp::kNumConfigRegionAttributes;
  result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
                               regionArguments)))
    return failure();

  // Introduce the body region and parse it. The region has
  // kNumConfigRegionAttributes arguments that correspond to
  // block/thread identifiers and grid/block sizes, all having `index` type.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArguments) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  SmallVector<int32_t, 8> segmentSizes(8, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
  return success();
}

/// Simplify the gpu.launch when the range of a thread or block ID is
/// trivially known to be one.
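///
/// For illustration (a sketch): when the grid's z-size operand is an
/// `arith.constant 1`, every use of the corresponding block id inside the
/// body is replaced by a freshly inserted `arith.constant 0` of index type.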
struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter &rewriter) const override {
    // If the range implies a single value for `id`, replace `id`'s uses by
    // zero.
    Value zero;
    bool simplified = false;
    auto constPropIdUses = [&](Value id, Value size) {
      // Check if size is trivially one.
      if (!matchPattern(size, m_One()))
        return;
      if (id.getUses().empty())
        return;
      if (!simplified) {
        // Create a zero value the first time.
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(&op.getBody().front());
        zero =
            rewriter.create<arith::ConstantIndexOp>(op.getLoc(), /*value=*/0);
      }
      rewriter.replaceAllUsesWith(id, zero);
      simplified = true;
    };
    constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
    constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
    constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
    constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
    constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
    constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());

    return success(simplified);
  }
};

void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
                                           MLIRContext *context) {
  rewrites.add<FoldLaunchArguments>(context);
}

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      LaunchOp::kNumConfigRegionAttributes + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies) {
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
  auto kernelSymbol =
      SymbolRefAttr::get(kernelModule.getNameAttr(),
                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});

  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernelSymbol;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  for (auto &sz : prop.operandSegmentSizes)
    sz = 1;
  prop.operandSegmentSizes[0] = asyncDependencies.size();
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernel, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Value asyncObject) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  if (asyncObject)
    result.addOperands(asyncObject);
  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernel;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  for (auto &sz : prop.operandSegmentSizes)
    sz = 1;
  prop.operandSegmentSizes[0] = 0;
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
}

StringAttr LaunchFuncOp::getKernelModuleName() {
  return getKernel().getRootReference();
}

StringAttr LaunchFuncOp::getKernelName() {
  return getKernel().getLeafReference();
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getKernelOperands().size();
}

Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getKernelOperands()[i];
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

LogicalResult LaunchFuncOp::verify() {
  auto module = (*this)->getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module->getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  return success();
}
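
// For illustration, a well-formed launch site checked by this verifier and by
// GPUDialect::verifyOperationAttribute above (a sketch with hypothetical
// names):
//
//   module attributes {gpu.container_module} {
//     gpu.module @kernels {
//       gpu.func @k(%arg0 : f32) kernel { gpu.return }
//     }
//     gpu.launch_func @kernels::@k
//         blocks in (%c1, %c1, %c1) threads in (%c4, %c1, %c1)
//         args(%v : f32)
//   }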

static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy) {
  if (succeeded(parser.parseOptionalColon())) {
    if (parser.parseType(dimTy))
      return failure();
  } else {
    dimTy = IndexType::get(parser.getContext());
  }
  return success();
}

static void printLaunchDimType(OpAsmPrinter &printer, Operation *op,
                               Type dimTy) {
  if (!dimTy.isIndex())
    printer << ": " << dimTy;
}

static ParseResult parseLaunchFuncOperands(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
    SmallVectorImpl<Type> &argTypes) {
  if (parser.parseOptionalKeyword("args"))
    return success();

  auto parseElement = [&]() -> ParseResult {
    return failure(parser.parseOperand(argNames.emplace_back()) ||
                   parser.parseColonType(argTypes.emplace_back()));
  };

  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
                                        parseElement, " in argument list");
}

static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
                                    OperandRange operands, TypeRange types) {
  if (operands.empty())
    return;
  printer << "args(";
  llvm::interleaveComma(llvm::zip(operands, types), printer,
                        [&](const auto &pair) {
                          printer.printOperand(std::get<0>(pair));
                          printer << " : ";
                          printer.printType(std::get<1>(pair));
                        });
  printer << ")";
}

//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//

void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
                      int32_t offset, int32_t width, ShuffleMode mode) {
  build(builder, result, value,
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(offset)),
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(width)),
        mode);
}
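
// For illustration, the op this builder produces once the offset and width
// are materialized as i32 constants (a sketch; the second result is the
// validity predicate):
//
//   %shfl, %valid = gpu.shuffle xor %value, %offset, %width : f32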

//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      getFunctionType().getNumInputs() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                      StringRef name, FunctionType type,
                      TypeRange workgroupAttributions,
                      TypeRange privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));
  result.addAttributes(attrs);
  Region *body = result.addRegion();
  Block *entryBlock = new Block;

  // TODO: Allow passing in proper locations here.
  for (Type argTy : type.getInputs())
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : workgroupAttributions)
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    entryBlock->addArgument(argTy, result.location);

  body->getBlocks().push_back(entryBlock);
}

/// Parses a GPU function memory attribution.
///
/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                        (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args,
                  Attribute &attributionAttrs) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  size_t existingArgs = args.size();
  ParseResult result =
      parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                               /*allowType=*/true, /*allowAttrs=*/true);
  if (failed(result))
    return result;

  bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
                               [](const OpAsmParser::Argument &arg) -> bool {
                                 return arg.attrs && !arg.attrs.empty();
                               });
  if (!hadAttrs) {
    attributionAttrs = nullptr;
    return result;
  }

  Builder &builder = parser.getBuilder();
  SmallVector<Attribute> attributionAttrsVec;
  for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
    if (!argument.attrs)
      attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
    else
      attributionAttrsVec.push_back(argument.attrs);
  }
  attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
  return result;
}

/// Parses a GPU function.
///
/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
///                 (`->` function-result-list)? memory-attribution `kernel`?
///                 function-attributes? region
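///
/// For illustration (a sketch; address spaces shown numerically):
///
///   gpu.func @kernel(%arg0 : memref<?xf32>)
///       workgroup(%sum : memref<32xf32, 3>)
///       private(%tmp : memref<1xf32, 5>)
///       kernel {
///     gpu.return
///   }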
ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  bool isVariadic;

  // Parse the function name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  auto signatureLocation = parser.getCurrentLocation();
  if (failed(function_interface_impl::parseFunctionSignature(
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          resultAttrs)))
    return failure();

  if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
    return parser.emitError(signatureLocation)
           << "gpu.func requires named arguments";

  // Construct the function type. More types will be added to the region, but
  // not to the function type.
  Builder &builder = parser.getBuilder();

  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto type = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));

  function_interface_impl::addArgAndResultAttrs(
      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
      getResAttrsAttrName(result.name));

  Attribute workgroupAttributionAttrs;
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
                               entryArgs, workgroupAttributionAttrs)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));
  if (workgroupAttributionAttrs)
    result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
                        workgroupAttributionAttrs);

  Attribute privateAttributionAttrs;
  // Parse private memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
                               entryArgs, privateAttributionAttrs)))
    return failure();
  if (privateAttributionAttrs)
    result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
                        privateAttributionAttrs);

  // Parse the kernel attribute if present.
  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
    result.addAttribute(GPUDialect::getKernelFuncAttrName(),
                        builder.getUnitAttr());

  // Parse attributes.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();

  // Parse the region. If no argument names were provided, take all names
  // (including those of attributions) from the entry block.
  auto *body = result.addRegion();
  return parser.parseRegion(*body, entryArgs);
}

static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values,
                              ArrayAttr attributes) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      llvm::enumerate(values), p, [&p, attributes](auto pair) {
        BlockArgument v = pair.value();
        p << v << " : " << v.getType();

        size_t attributionIndex = pair.index();
        DictionaryAttr attrs;
        if (attributes && attributionIndex < attributes.size())
          attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
        if (attrs)
          p.printOptionalAttrDict(attrs.getValue());
      });
  p << ')';
}

void GPUFuncOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(getName());

  FunctionType type = getFunctionType();
  function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
                                                  /*isVariadic=*/false,
                                                  type.getResults());

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
                    getWorkgroupAttribAttrs().value_or(nullptr));
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
                    getPrivateAttribAttrs().value_or(nullptr));
  if (isKernel())
    p << ' ' << getKernelKeyword();

  function_interface_impl::printFunctionAttributes(
      p, *this,
      {getNumWorkgroupAttributionsAttrName(),
       GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
       getArgAttrsAttrName(), getResAttrsAttrName(),
       getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
  p << ' ';
  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}

static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
                                          StringAttr attrName) {
  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
  if (!allAttrs || index >= allAttrs.size())
    return DictionaryAttr();
  return llvm::cast<DictionaryAttr>(allAttrs[index]);
}

DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
  return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
}

DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
  return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
}

static void setAttributionAttrs(GPUFuncOp op, unsigned index,
                                DictionaryAttr value, StringAttr attrName) {
  MLIRContext *ctx = op.getContext();
  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
  SmallVector<Attribute> elements;
  if (allAttrs)
    elements.append(allAttrs.begin(), allAttrs.end());
  while (elements.size() <= index)
    elements.push_back(DictionaryAttr::get(ctx));
  if (!value)
    elements[index] = DictionaryAttr::get(ctx);
  else
    elements[index] = value;
  ArrayAttr newValue = ArrayAttr::get(ctx, elements);
  op->setAttr(attrName, newValue);
}

void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
                                             DictionaryAttr value) {
  setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
}

void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
                                           DictionaryAttr value) {
  setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
}

static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
                                    StringAttr name, StringAttr attrsName) {
  DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
  if (!dict)
    return Attribute();
  return dict.get(name);
}

Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
                                                 StringAttr name) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  return getAttributionAttr(*this, index, name,
                            getWorkgroupAttribAttrsAttrName());
}

Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
                                               StringAttr name) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  return getAttributionAttr(*this, index, name,
                            getPrivateAttribAttrsAttrName());
}

static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
                               Attribute value, StringAttr attrsName) {
  MLIRContext *ctx = op.getContext();
  SmallVector<NamedAttribute> elems;
  DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
  if (oldDict)
    elems.append(oldDict.getValue().begin(), oldDict.getValue().end());

  bool found = false;
  bool mustSort = true;
  for (unsigned i = 0, e = elems.size(); i < e; ++i) {
    if (elems[i].getName() == name) {
      found = true;
      if (!value) {
        std::swap(elems[i], elems[elems.size() - 1]);
        elems.pop_back();
      } else {
        mustSort = false;
        elems[i] = NamedAttribute(elems[i].getName(), value);
      }
      break;
    }
  }
  if (!found) {
    if (!value)
      return;
    elems.emplace_back(name, value);
  }
  if (mustSort) {
    DictionaryAttr::sortInPlace(elems);
  }
  auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
  setAttributionAttrs(op, index, newDict, attrsName);
}

void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
                                            Attribute value) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  setAttributionAttr(*this, index, name, value,
                     getWorkgroupAttribAttrsAttrName());
}

void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
                                          Attribute value) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  setAttributionAttr(*this, index, name, value,
                     getPrivateAttribAttrsAttrName());
}

LogicalResult GPUFuncOp::verifyType() {
  if (isKernel() && getFunctionType().getNumResults() != 0)
    return emitOpError() << "expected void return type for kernel function";

  return success();
}

/// Verifies the body of the function.
LogicalResult GPUFuncOp::verifyBody() {
  if (empty())
    return emitOpError() << "expected body with at least one block";
  unsigned numFuncArguments = getNumArguments();
  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
  unsigned numBlockArguments = front().getNumArguments();
  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
    return emitOpError() << "expected at least "
                         << numFuncArguments + numWorkgroupAttributions
                         << " arguments to body region";

  ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
  for (unsigned i = 0; i < numFuncArguments; ++i) {
    Type blockArgType = front().getArgument(i).getType();
    if (funcArgTypes[i] != blockArgType)
      return emitOpError() << "expected body region argument #" << i
                           << " to be of type " << funcArgTypes[i] << ", got "
                           << blockArgType;
  }

  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  return success();
}

static LogicalResult verifyKnownLaunchSizeAttr(gpu::GPUFuncOp op,
                                               StringRef attrName) {
  auto maybeAttr = op->getAttr(attrName);
  if (!maybeAttr)
    return success();
  auto array = llvm::dyn_cast<DenseI32ArrayAttr>(maybeAttr);
  if (!array)
    return op.emitOpError(attrName + " must be a dense i32 array");
  if (array.size() != 3)
    return op.emitOpError(attrName + " must contain exactly 3 elements");
  return success();
}
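
// For illustration, the attribute shape this checks (a sketch; the exact
// attribute name is whatever getKnownBlockSizeAttrName() /
// getKnownGridSizeAttrName() return):
//
//   gpu.func @k() kernel
//       attributes {gpu.known_block_size = array<i32: 128, 1, 1>} {
//     gpu.return
//   }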

LogicalResult GPUFuncOp::verify() {
  if (failed(verifyKnownLaunchSizeAttr(*this, getKnownBlockSizeAttrName())))
    return failure();
  if (failed(verifyKnownLaunchSizeAttr(*this, getKnownGridSizeAttrName())))
    return failure();
  return success();
}

//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//

LogicalResult gpu::ReturnOp::verify() {
  GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();

  FunctionType funType = function.getFunctionType();

  if (funType.getNumResults() != getOperands().size())
    return emitOpError()
        .append("expected ", funType.getNumResults(), " result operands")
        .attachNote(function.getLoc())
        .append("return type declared here");

  for (const auto &pair : llvm::enumerate(
           llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
    auto [type, operand] = pair.value();
    if (type != operand.getType())
      return emitOpError() << "unexpected type `" << operand.getType()
                           << "' for operand #" << pair.index();
  }
  return success();
}
1563 
1564 //===----------------------------------------------------------------------===//
1565 // GPUModuleOp
1566 //===----------------------------------------------------------------------===//
1567 
1568 void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1569  StringRef name, ArrayAttr targets) {
1570  ensureTerminator(*result.addRegion(), builder, result.location);
1571  result.attributes.push_back(builder.getNamedAttr(
1572  ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1573 
1574  if (targets)
1575  result.getOrAddProperties<Properties>().targets = targets;
1576 }
1577 
1578 void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1579  StringRef name, ArrayRef<Attribute> targets) {
1580  build(builder, result, name,
1581  targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets));
1582 }
1583 
1584 ParseResult GPUModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1585  StringAttr nameAttr;
1586  ArrayAttr targetsAttr;
1587 
1588  if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(),
1589  result.attributes))
1590  return failure();
1591 
1592  // Parse the optional array of target attributes.
1593  OptionalParseResult targetsAttrResult =
1594  parser.parseOptionalAttribute(targetsAttr, Type{});
1595  if (targetsAttrResult.has_value()) {
1596  if (failed(*targetsAttrResult)) {
1597  return failure();
1598  }
1599  result.getOrAddProperties<Properties>().targets = targetsAttr;
1600  }
1601 
1602  // If module attributes are present, parse them.
1603  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1604  return failure();
1605 
1606  // Parse the module body.
1607  auto *body = result.addRegion();
1608  if (parser.parseRegion(*body, {}))
1609  return failure();
1610 
1611  // Ensure that this module has a valid terminator.
1612  GPUModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
1613  return success();
1614 }
1615 
1616 void GPUModuleOp::print(OpAsmPrinter &p) {
1617  p << ' ';
1618  p.printSymbolName(getName());
1619 
1620  if (Attribute attr = getTargetsAttr()) {
1621  p << ' ';
1622  p.printAttribute(attr);
1623  p << ' ';
1624  }
1625 
1626  p.printOptionalAttrDictWithKeyword(
1627  (*this)->getAttrs(),
1628  {mlir::SymbolTable::getSymbolAttrName(), getTargetsAttrName()});
1629  p << ' ';
1630  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
1631  /*printBlockTerminators=*/false);
1632 }
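// Illustrative round-trip of the syntax handled by parse/print above; the
// target attribute and its parameters are assumptions:
//
//   gpu.module @kernels [#nvvm.target<chip = "sm_90">] {
//     ...
//   }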
1633 
1634 bool GPUModuleOp::hasTarget(Attribute target) {
1635  if (ArrayAttr targets = getTargetsAttr())
1636  return llvm::count(targets.getValue(), target);
1637  return false;
1638 }
1639 
1640 void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
1641  ArrayAttr &targetsAttr = getProperties().targets;
1642  SmallVector<Attribute> targetsVector(targets);
1643  targetsAttr = ArrayAttr::get(getContext(), targetsVector);
1644 }
1645 
1646 //===----------------------------------------------------------------------===//
1647 // GPUBinaryOp
1648 //===----------------------------------------------------------------------===//
1649 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1650  Attribute offloadingHandler, ArrayAttr objects) {
1651  auto &properties = result.getOrAddProperties<Properties>();
1652  result.attributes.push_back(builder.getNamedAttr(
1653  SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1654  properties.objects = objects;
1655  if (offloadingHandler)
1656  properties.offloadingHandler = offloadingHandler;
1657  else
1658  properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
1659 }
1660 
1661 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1662  Attribute offloadingHandler, ArrayRef<Attribute> objects) {
1663  build(builder, result, name, offloadingHandler,
1664  objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
1665 }
1666 
1667 static ParseResult parseOffloadingHandler(OpAsmParser &parser,
1668  Attribute &offloadingHandler) {
1669  if (succeeded(parser.parseOptionalLess())) {
1670  if (parser.parseAttribute(offloadingHandler))
1671  return failure();
1672  if (parser.parseGreater())
1673  return failure();
1674  }
1675  if (!offloadingHandler)
1676  offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
1677  return success();
1678 }
1679 
1680 static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op,
1681  Attribute offloadingHandler) {
1682  if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
1683  printer << '<' << offloadingHandler << '>';
1684 }
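// Sketch of the resulting gpu.binary syntax (object payloads elided); the
// default SelectObjectAttr handler is not printed, while a non-default
// handler appears between angle brackets:
//
//   gpu.binary @kernels [#gpu.object<#nvvm.target, "">]
//   gpu.binary @kernels <#gpu.select_object<1>>
//       [#gpu.object<#nvvm.target, "">, #gpu.object<#rocdl.target, "">]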
1685 
1686 //===----------------------------------------------------------------------===//
1687 // GPUMemcpyOp
1688 //===----------------------------------------------------------------------===//
1689 
1690 LogicalResult MemcpyOp::verify() {
1691  auto srcType = getSrc().getType();
1692  auto dstType = getDst().getType();
1693 
1694  if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
1695  return emitOpError("arguments have incompatible element type");
1696 
1697  if (failed(verifyCompatibleShape(srcType, dstType)))
1698  return emitOpError("arguments have incompatible shape");
1699 
1700  return success();
1701 }
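// Hedged example accepted by this verifier: both sides have the same element
// type and compatible shapes.
//
//   %token = gpu.memcpy async [%dep] %dst, %src : memref<16xf32>, memref<16xf32>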
1702 
1703 namespace {
1704 
1705 /// Erases a common case of copy ops where the destination value is used only
1706 /// by the copy op itself, the op that allocated it, and dealloc ops.
1707 struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
1708  using OpRewritePattern<MemcpyOp>::OpRewritePattern;
1709 
1710  LogicalResult matchAndRewrite(MemcpyOp op,
1711  PatternRewriter &rewriter) const override {
1712  Value dest = op.getDst();
1713  Operation *destDefOp = dest.getDefiningOp();
1714  // `dest` must be defined by an op that has the Allocate memory effect for
1715  // the folding to apply.
1716  if (!destDefOp ||
1717  !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
1718  return failure();
1719  // We can erase `op` iff `dest` has no other use apart from its
1720  // use by `op` and dealloc ops.
1721  if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
1722  return user != op &&
1723  !hasSingleEffect<MemoryEffects::Free>(user, dest);
1724  }))
1725  return failure();
1726  // We can perform the folding if and only if op has a single async
1727  // dependency and produces an async token as result, or if it does not have
1728  // any async dependency and does not produce any async token result.
1729  if (op.getAsyncDependencies().size() > 1 ||
1730  ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
1731  (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
1732  return failure();
1733  rewriter.replaceOp(op, op.getAsyncDependencies());
1734  return success();
1735  }
1736 };
1737 
1738 } // end anonymous namespace
1739 
1740 void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1741  MLIRContext *context) {
1742  results.add<EraseTrivialCopyOp>(context);
1743 }
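// Sketch of IR erased by EraseTrivialCopyOp: %dst is used only by the copy
// and its dealloc, so the copy is dead.
//
//   %dst = memref.alloc() : memref<16xf32>
//   gpu.memcpy %dst, %src : memref<16xf32>, memref<16xf32>
//   memref.dealloc %dst : memref<16xf32>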
1744 
1745 //===----------------------------------------------------------------------===//
1746 // GPU_SubgroupMmaLoadMatrixOp
1747 //===----------------------------------------------------------------------===//
1748 
1749 LogicalResult SubgroupMmaLoadMatrixOp::verify() {
1750  auto srcType = getSrcMemref().getType();
1751  auto resType = getRes().getType();
1752  auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
1753  auto operand = resMatrixType.getOperand();
1754  auto srcMemrefType = llvm::cast<MemRefType>(srcType);
1755 
1756  if (!isLastMemrefDimUnitStride(srcMemrefType))
1757  return emitError(
1758  "expected source memref most minor dim must have unit stride");
1759 
1760  if (!operand.equals("AOp") && !operand.equals("BOp") &&
1761  !operand.equals("COp"))
1762  return emitError("only AOp, BOp and COp can be loaded");
1763 
1764  return success();
1765 }
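// Illustrative load accepted by this verifier (indices, shapes, and
// leadDimension are assumptions):
//
//   %a = gpu.subgroup_mma_load_matrix %src[%i, %j] {leadDimension = 32 : index}
//        : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">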
1766 
1767 //===----------------------------------------------------------------------===//
1768 // GPU_SubgroupMmaStoreMatrixOp
1769 //===----------------------------------------------------------------------===//
1770 
1771 LogicalResult SubgroupMmaStoreMatrixOp::verify() {
1772  auto srcType = getSrc().getType();
1773  auto dstType = getDstMemref().getType();
1774  auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
1775  auto dstMemrefType = llvm::cast<MemRefType>(dstType);
1776 
1777  if (!isLastMemrefDimUnitStride(dstMemrefType))
1778  return emitError(
1779  "expected destination memref most minor dim must have unit stride");
1780 
1781  if (!srcMatrixType.getOperand().equals("COp"))
1782  return emitError(
1783  "expected the operand matrix being stored to have 'COp' operand type");
1784 
1785  return success();
1786 }
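// Illustrative store accepted by this verifier; only "COp" matrices may be
// stored:
//
//   gpu.subgroup_mma_store_matrix %c, %dst[%i, %j] {leadDimension = 32 : index}
//       : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>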
1787 
1788 //===----------------------------------------------------------------------===//
1789 // GPU_SubgroupMmaComputeOp
1790 //===----------------------------------------------------------------------===//
1791 
1792 LogicalResult SubgroupMmaComputeOp::verify() {
1793  enum OperandMap { A, B, C };
1794  SmallVector<MMAMatrixType, 3> opTypes;
1795  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
1796  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
1797  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
1798 
1799  if (!opTypes[A].getOperand().equals("AOp") ||
1800  !opTypes[B].getOperand().equals("BOp") ||
1801  !opTypes[C].getOperand().equals("COp"))
1802  return emitError("operands must be in the order AOp, BOp, COp");
1803 
1804  ArrayRef<int64_t> aShape, bShape, cShape;
1805  aShape = opTypes[A].getShape();
1806  bShape = opTypes[B].getShape();
1807  cShape = opTypes[C].getShape();
1808 
1809  if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
1810  bShape[1] != cShape[1])
1811  return emitError("operand shapes do not satisfy matmul constraints");
1812 
1813  return success();
1814 }
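// The shape check above enforces A:(m x k), B:(k x n), C:(m x n). A hedged
// example with m = n = k = 16:
//
//   %d = gpu.subgroup_mma_compute %a, %b, %c
//        : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
//        -> !gpu.mma_matrix<16x16xf16, "COp">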
1815 
1816 LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
1817  SmallVectorImpl<::mlir::OpFoldResult> &results) {
1818  return memref::foldMemRefCast(*this);
1819 }
1820 
1821 LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
1822  SmallVectorImpl<::mlir::OpFoldResult> &results) {
1823  return memref::foldMemRefCast(*this);
1824 }
1825 
1826 //===----------------------------------------------------------------------===//
1827 // GPU_WaitOp
1828 //===----------------------------------------------------------------------===//
1829 
1830 namespace {
1831 
1832 /// Removes a gpu.wait op's use of a token defined by a gpu.wait op that has no async dependencies.
1833 /// %t = gpu.wait async [] // No async dependencies.
1834 /// ... gpu.wait ... [%t, ...] // %t can be removed.
1835 struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
1836 public:
1837  using OpRewritePattern::OpRewritePattern;
1838 
1839  LogicalResult matchAndRewrite(WaitOp op,
1840  PatternRewriter &rewriter) const final {
1841  auto predicate = [](Value value) {
1842  auto waitOp = value.getDefiningOp<WaitOp>();
1843  return waitOp && waitOp->getNumOperands() == 0;
1844  };
1845  if (llvm::none_of(op.getAsyncDependencies(), predicate))
1846  return failure();
1847  SmallVector<Value> validOperands;
1848  for (Value operand : op->getOperands()) {
1849  if (predicate(operand))
1850  continue;
1851  validOperands.push_back(operand);
1852  }
1853  rewriter.updateRootInPlace(op, [&]() { op->setOperands(validOperands); });
1854  return success();
1855  }
1856 };
1857 
1858 /// Simplify trivial gpu.wait ops for the following patterns.
1859 /// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
1860 /// dependencies).
1861 /// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with
1862 /// %t0.
1863 /// 3. gpu.wait [] ops, i.e. gpu.wait ops that neither have any async
1864 /// dependencies nor return any token.
1865 struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
1866 public:
1867  using OpRewritePattern::OpRewritePattern;
1868 
1869  LogicalResult matchAndRewrite(WaitOp op,
1870  PatternRewriter &rewriter) const final {
1871  // Erase gpu.wait ops that neither have any async dependencies nor return
1872  // any async token.
1873  if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
1874  rewriter.eraseOp(op);
1875  return success();
1876  }
1877  // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
1878  if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
1879  op.getAsyncToken()) {
1880  rewriter.replaceOp(op, op.getAsyncDependencies());
1881  return success();
1882  }
1883  // Erase %t = gpu.wait async ... ops, where %t has no uses.
1884  if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
1885  rewriter.eraseOp(op);
1886  return success();
1887  }
1888  return failure();
1889  }
1890 };
1891 
1892 } // end anonymous namespace
1893 
1894 void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
1895  MLIRContext *context) {
1896  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
1897 }
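// Sketch of the rewrites these patterns perform:
//
//   %t0 = gpu.wait async            // no async dependencies
//   %t1 = gpu.wait async [%t0]      // SimplifyGpuWaitOp: uses of %t1 -> %t0
//   gpu.wait [%t0, %t2]             // EraseRedundantGpuWaitOpPairs: drops %t0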
1898 
1899 //===----------------------------------------------------------------------===//
1900 // GPU_AllocOp
1901 //===----------------------------------------------------------------------===//
1902 
1903 LogicalResult AllocOp::verify() {
1904  auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
1905 
1906  if (static_cast<int64_t>(getDynamicSizes().size()) !=
1907  memRefType.getNumDynamicDims())
1908  return emitOpError("dimension operand count does not equal memref "
1909  "dynamic dimension count");
1910 
1911  unsigned numSymbols = 0;
1912  if (!memRefType.getLayout().isIdentity())
1913  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
1914  if (getSymbolOperands().size() != numSymbols) {
1915  return emitOpError(
1916  "symbol operand count does not equal memref symbol count");
1917  }
1918 
1919  return success();
1920 }
1921 
1922 namespace {
1923 
1924 /// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
1925 /// `memref::AllocOp`.
1926 struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
1927  using OpRewritePattern<memref::DimOp>::OpRewritePattern;
1928 
1929  LogicalResult matchAndRewrite(memref::DimOp dimOp,
1930  PatternRewriter &rewriter) const override {
1931  std::optional<int64_t> index = dimOp.getConstantIndex();
1932  if (!index)
1933  return failure();
1934 
1935  auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
1936  if (!memrefType || !memrefType.isDynamicDim(index.value()))
1937  return failure();
1938 
1939  auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
1940  if (!alloc)
1941  return failure();
1942 
1943  Value substituteOp = *(alloc.getDynamicSizes().begin() +
1944  memrefType.getDynamicDimIndex(index.value()));
1945  rewriter.replaceOp(dimOp, substituteOp);
1946  return success();
1947  }
1948 };
1949 
1950 } // namespace
1951 
1952 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
1953  MLIRContext *context) {
1954  results.add<SimplifyDimOfAllocOp>(context);
1955 }
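// Sketch of the memref.dim folding enabled by SimplifyDimOfAllocOp:
//
//   %mem = gpu.alloc (%sz) : memref<4x?xf32>
//   %c1 = arith.constant 1 : index
//   %d = memref.dim %mem, %c1 : memref<4x?xf32>   // rewritten to %sz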
1956 
1957 //===----------------------------------------------------------------------===//
1958 // GPU object attribute
1959 //===----------------------------------------------------------------------===//
1960 
1961 LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1962  Attribute target, CompilationTarget format,
1963  StringAttr object, DictionaryAttr properties) {
1964  if (!target)
1965  return emitError() << "the target attribute cannot be null";
1966  if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
1967  return success();
1968  return emitError() << "the target attribute must implement or promise the "
1969  "`gpu::TargetAttrInterface`";
1970 }
1971 
1972 namespace {
1973 LogicalResult parseObject(AsmParser &odsParser, CompilationTarget &format,
1974  StringAttr &object) {
1975  std::optional<CompilationTarget> formatResult;
1976  StringRef enumKeyword;
1977  auto loc = odsParser.getCurrentLocation();
1978  if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
1979  formatResult = CompilationTarget::Fatbin;
1980  if (!formatResult &&
1981  (formatResult =
1982  gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
1983  odsParser.parseEqual())
1984  return odsParser.emitError(loc, "expected an equal sign");
1985  if (!formatResult)
1986  return odsParser.emitError(loc, "expected keyword for GPU object format");
1987  FailureOr<StringAttr> objectResult =
1988  FieldParser<StringAttr>::parse(odsParser);
1989  if (failed(objectResult))
1990  return odsParser.emitError(odsParser.getCurrentLocation(),
1991  "failed to parse GPU_ObjectAttr parameter "
1992  "'object' which is to be a `StringAttr`");
1993  format = *formatResult;
1994  object = *objectResult;
1995  return success();
1996 }
1997 
1998 void printObject(AsmPrinter &odsParser, CompilationTarget format,
1999  StringAttr object) {
2000  if (format != CompilationTarget::Fatbin)
2001  odsParser << stringifyEnum(format) << " = ";
2002  odsParser << object;
2003 }
2004 } // namespace
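// Hedged examples of the attribute syntax these helpers handle; the format
// keywords are assumed to match the CompilationTarget enum, with fatbin as
// the default that is elided when printing:
//
//   #gpu.object<#nvvm.target, "fatbin-bytes">
//   #gpu.object<#nvvm.target, bin = "cubin-bytes">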
2005 
2006 //===----------------------------------------------------------------------===//
2007 // GPU select object attribute
2008 //===----------------------------------------------------------------------===//
2009 
2010 LogicalResult
2011 gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2012  Attribute target) {
2013  // Check `target`, it can be null, an integer attr or a GPU Target attribute.
2014  if (target) {
2015  if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2016  if (intAttr.getInt() < 0) {
2017  return emitError() << "the object index must be positive";
2018  }
2019  } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2020  return emitError()
2021  << "the target attribute must be a GPU Target attribute";
2022  }
2023  }
2024  return success();
2025 }
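// Illustrative forms accepted by this verifier:
//
//   #gpu.select_object<1>             // select an object by index
//   #gpu.select_object<#nvvm.target>  // select an object by target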
2026 
2027 //===----------------------------------------------------------------------===//
2028 // GPU target options
2029 //===----------------------------------------------------------------------===//
2030 
2031 TargetOptions::TargetOptions(
2032  StringRef toolkitPath, ArrayRef<std::string> linkFiles,
2033  StringRef cmdOptions, CompilationTarget compilationTarget,
2034  function_ref<SymbolTable *()> getSymbolTableCallback)
2035  : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles,
2036  cmdOptions, compilationTarget, getSymbolTableCallback) {}
2037 
2038 TargetOptions::TargetOptions(
2039  TypeID typeID, StringRef toolkitPath, ArrayRef<std::string> linkFiles,
2040  StringRef cmdOptions, CompilationTarget compilationTarget,
2041  function_ref<SymbolTable *()> getSymbolTableCallback)
2042  : toolkitPath(toolkitPath.str()), linkFiles(linkFiles),
2043  cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget),
2044  getSymbolTableCallback(getSymbolTableCallback), typeID(typeID) {}
2045 
2046 TypeID TargetOptions::getTypeID() const { return typeID; }
2047 
2048 StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2049 
2050 ArrayRef<std::string> TargetOptions::getLinkFiles() const { return linkFiles; }
2051 
2052 StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2053 
2054 SymbolTable *TargetOptions::getSymbolTable() const {
2055  return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
2056 }
2057 
2058 CompilationTarget TargetOptions::getCompilationTarget() const {
2059  return compilationTarget;
2060 }
2061 
2062 CompilationTarget TargetOptions::getDefaultCompilationTarget() {
2063  return CompilationTarget::Fatbin;
2064 }
2065 
2066 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2067 TargetOptions::tokenizeCmdOptions() const {
2068  std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2069  llvm::StringSaver stringSaver(options.first);
2070  StringRef opts = cmdOptions;
2071  // For correct tokenization of the command-line options, `opts` must be
2072  // unquoted; otherwise the tokenizer returns a single token holding the
2073  // whole quoted `cmdOptions` string, which is not the desired behavior.
2074  // Remove any quotes if they are at the beginning and end of the string:
2075  if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2076  opts.consume_front("\""), opts.consume_back("\"");
2077  if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
2078  opts.consume_front("'"), opts.consume_back("'");
2079 #ifdef _WIN32
2080  llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
2081  /*MarkEOLs=*/false);
2082 #else
2083  llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
2084  /*MarkEOLs=*/false);
2085 #endif // _WIN32
2086  return options;
2087 }
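// A hedged usage sketch (not in this file), assuming the declaration's
// default arguments for the remaining TargetOptions parameters:
//
//   gpu::TargetOptions opts(/*toolkitPath=*/"", /*linkFiles=*/{},
//                           /*cmdOptions=*/"\"-O3 -v\"");
//   auto [alloc, args] = opts.tokenizeCmdOptions();
//   // args now holds {"-O3", "-v"}; the surrounding quotes were stripped
//   // before tokenization.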
2088 
2089 MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::gpu::TargetOptions)
2090 
2091 #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2092 #include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2093 
2094 #define GET_ATTRDEF_CLASSES
2095 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2096 
2097 #define GET_OP_CLASSES
2098 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2099 
2100 #include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"