//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/GPUDialect.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/StringSaver.h"
#include <cassert>
#include <numeric>
using namespace mlir;
using namespace mlir::gpu;

#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// GPU Device Mapping Attributes
//===----------------------------------------------------------------------===//

int64_t GPUBlockMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getBlock());
}

bool GPUBlockMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUBlockMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpgroupMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarpgroup());
}

bool GPUWarpgroupMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarp());
}

bool GPUWarpMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUThreadMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getThread());
}

bool GPUThreadMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUThreadMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPULaneMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getLane());
}

bool GPULaneMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPULaneMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds() const { return 64; }

///                  8       4       0
/// Example mask   : 0 0 0 1 1 0 1 0 0
///
/// Active physical (resp. logical) ids are 2 (0), 4 (1) and 5 (2).
/// The logical id for e.g. physical id 5 (logical 2) is obtained with the
/// filter (1 << 5) - 1:
///
/// Example mask   : 0 0 0 1 1 0 1 0 0
/// Example filter : 0 0 0 0 1 1 1 1 1
/// Intersection   : 0 0 0 0 1 0 1 0 0
/// PopCnt         : 2
Value GPUMappingMaskAttr::createLogicalLinearMappingId(
    OpBuilder &b, Value physicalLinearMappingId) const {
  Location loc = physicalLinearMappingId.getLoc();
  Value mask =
      arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
  Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
  Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
  filter = arith::SubIOp::create(b, loc, filter, one);
  Value filteredId = arith::AndIOp::create(b, loc, mask, filter);
  return math::CtPopOp::create(b, loc, filteredId);
}
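
// A sketch of the IR materialized by the helper above (SSA names are
// hypothetical), for the example mask 52 (0b000110100, active ids {2, 4, 5})
// and a physical id %pid = 5:
//
//   %mask    = arith.constant 52 : i64
//   %one     = arith.constant 1 : i64
//   %shifted = arith.shli %one, %pid : i64      // 1 << 5
//   %filter  = arith.subi %shifted, %one : i64  // (1 << 5) - 1 = 0b000011111
//   %masked  = arith.andi %mask, %filter : i64  // 0b000010100
//   %logical = math.ctpop %masked : i64         // 2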

///                  8       4       0
/// Example mask   : 0 0 0 1 1 0 1 0 0
///
/// Active physical (resp. logical) ids are 2 (0), 4 (1) and 5 (2).
/// The predicate for e.g. physical id 5 (logical 2) is obtained with the
/// filter 1 << 5:
///
/// Example mask   : 0 0 0 1 1 0 1 0 0
/// Example filter : 0 0 0 1 0 0 0 0 0
/// Intersection   : 0 0 0 1 0 0 0 0 0
/// Cmp            : 1
Value GPUMappingMaskAttr::createIsActiveIdPredicate(
    OpBuilder &b, Value physicalLinearMappingId) const {
  Location loc = physicalLinearMappingId.getLoc();
  Value mask =
      arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
  Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
  Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
  Value filtered = arith::AndIOp::create(b, loc, mask, filter);
  Value zero = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(0));
  return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::ne, filtered,
                               zero);
}
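
// A sketch of the emitted predicate (SSA names hypothetical): with the mask
// 52 above, %pid = 4 yields true and %pid = 3 yields false.
//
//   %mask   = arith.constant 52 : i64
//   %one    = arith.constant 1 : i64
//   %bit    = arith.shli %one, %pid : i64   // 1 << pid
//   %hit    = arith.andi %mask, %bit : i64
//   %zero   = arith.constant 0 : i64
//   %active = arith.cmpi ne, %hit, %zero : i64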

int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getAddressSpace());
}

bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
}

int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
}

//===----------------------------------------------------------------------===//
// MMAMatrixType
//===----------------------------------------------------------------------===//

MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
                                 StringRef operand) {
  return Base::get(elementType.getContext(), shape, elementType, operand);
}

MMAMatrixType
MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                          ArrayRef<int64_t> shape, Type elementType,
                          StringRef operand) {
  return Base::getChecked(emitError, elementType.getContext(), shape,
                          elementType, operand);
}

unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }

ArrayRef<int64_t> MMAMatrixType::getShape() const {
  return getImpl()->getShape();
}

Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }

StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }

bool MMAMatrixType::isValidElementType(Type elementType) {
  return elementType.isF16() || elementType.isF32() ||
         elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
         elementType.isInteger(32);
}

LogicalResult
MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
                                ArrayRef<int64_t> shape, Type elementType,
                                StringRef operand) {
  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return emitError() << "operand expected to be one of AOp, BOp or COp";

  if (shape.size() != 2)
    return emitError() << "MMAMatrixType must have exactly two dimensions";

  if (!MMAMatrixType::isValidElementType(elementType))
    return emitError()
           << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";

  return success();
}
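
// For reference, a type that passes the checks above, in IR syntax, and one
// that is rejected (illustrative):
//
//   !gpu.mma_matrix<16x16xf16, "AOp">   // ok: 2-D, f16, known operand tag
//   !gpu.mma_matrix<2x4x8xf32, "COp">   // error: exactly two dims required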

//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//

bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == getWorkgroupAddressSpace();
  return false;
}

bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isWorkgroupMemoryAddressSpace(memorySpace);
}

bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

namespace {
/// This class defines the interface for handling inlining with gpu
/// operations.
struct GPUInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All gpu dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace

void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addTypes<SparseDnTensorHandleType>();
  addTypes<SparseSpMatHandleType>();
  addTypes<SparseSpGEMMOpHandleType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
                           TerminatorOp>();
  declarePromisedInterfaces<
      ValueBoundsOpInterface, ClusterDimOp, ClusterDimBlocksOp, ClusterIdOp,
      ClusterBlockIdOp, BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp, LaneIdOp,
      SubgroupIdOp, GlobalIdOp, NumSubgroupsOp, SubgroupSizeOp, LaunchOp>();
}

static StringRef getSparseHandleKeyword(SparseHandleKind kind) {
  switch (kind) {
  case SparseHandleKind::DnTensor:
    return "sparse.dntensor_handle";
  case SparseHandleKind::SpMat:
    return "sparse.spmat_handle";
  case SparseHandleKind::SpGEMMOp:
    return "sparse.spgemmop_handle";
  }
  llvm_unreachable("unknown sparse handle kind");
  return "";
}

Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(context);

  if (keyword == "mma_matrix") {
    SMLoc beginLoc = parser.getNameLoc();

    // Parse '<'.
    if (parser.parseLess())
      return nullptr;

    // Parse the size and elementType.
    SmallVector<int64_t> shape;
    Type elementType;
    if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
        parser.parseType(elementType))
      return nullptr;

    // Parse ','
    if (parser.parseComma())
      return nullptr;

    // Parse operand.
    std::string operand;
    if (failed(parser.parseOptionalString(&operand)))
      return nullptr;

    // Parse '>'.
    if (parser.parseGreater())
      return nullptr;

    return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
                                         parser.getEncodedSourceLoc(beginLoc)),
                                     shape, elementType, operand);
  }

  if (keyword == getSparseHandleKeyword(SparseHandleKind::DnTensor))
    return SparseDnTensorHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpMat))
    return SparseSpMatHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpGEMMOp))
    return SparseSpGEMMOpHandleType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
  return Type();
}
// TODO: Print the refined type here; note that it should correspond to the
// parser above.
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
      .Case<SparseDnTensorHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::DnTensor);
      })
      .Case<SparseSpMatHandleType>(
          [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
      .Case<SparseSpGEMMOpHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
      })
      .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
        os << "mma_matrix<";
        auto shape = fragTy.getShape();
        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
          os << *dim << 'x';
        os << shape.back() << 'x' << fragTy.getElementType();
        os << ", \"" << fragTy.getOperand() << "\"" << '>';
      })
      .DefaultUnreachable("unexpected 'gpu' type kind");
}

static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
                                               NamedAttribute attr) {
  auto array = dyn_cast<DenseI32ArrayAttr>(attr.getValue());
  if (!array)
    return op->emitOpError(Twine(attr.getName()) +
                           " must be a dense i32 array");
  if (array.size() != 3)
    return op->emitOpError(Twine(attr.getName()) +
                           " must contain exactly 3 elements");
  return success();
}
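
// Illustrative use of an attribute accepted by this verifier; the exact
// attribute name is the one reported by the known-size attr helpers used in
// verifyOperationAttribute below:
//
//   func.func @kernel() attributes {
//     gpu.known_block_size = array<i32: 128, 1, 1>
//   } { ... }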

LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownGridSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (!llvm::isa<UnitAttr>(attr.getValue()) ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deeply than functions in
    // the module we are currently checking.
    if (!launchOp->getParentOp() ||
        launchOp->getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp->getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName(launchOp->getName())))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel container.
    StringAttr kernelContainerName = launchOp.getKernelModuleName();
    Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
    if (!kernelContainer)
      return launchOp.emitOpError()
             << "kernel container '" << kernelContainerName.getValue()
             << "' is undefined";

    // If the container is a GPU binary op return success.
    if (isa<BinaryOp>(kernelContainer))
      return success();

    auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelContainerName.getValue()
             << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
    if (!kernelFunc)
      return launchOp.emitOpError("kernel function '")
             << launchOp.getKernel() << "' is undefined";
    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
    if (!kernelConvertedFunction) {
      InFlightDiagnostic diag = launchOp.emitOpError()
                                << "referenced kernel '" << launchOp.getKernel()
                                << "' is not a function";
      diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
      return diag;
    }

    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // TODO: If the kernel isn't a GPU function (which happens during separate
    // compilation), do not check type correspondence as it would require the
    // verifier to be aware of the type conversion.
    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
    if (!kernelGPUFunction)
      return success();

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    auto functionType = kernelGPUFunction.getFunctionType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}

/// Parses an optional list of async operands with an optional leading keyword.
///   (`async`)? (`[` ssa-id-list `]`)?
///
/// This method is used by the tablegen assembly format for async ops as well.
static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
}
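
// Example of the concrete syntax handled by this parser, on an async op such
// as gpu.wait (illustrative):
//
//   %token = gpu.wait async [%dep0, %dep1]  // named result, with deps
//   gpu.wait [%dep0]                        // dependencies only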

/// Prints optional async dependencies with their leading keyword.
///   (`async`)? (`[` ssa-id-list `]`)?
// Used by the tablegen assembly format for several async ops.
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async";
  if (asyncDependencies.empty())
    return;
  if (asyncTokenType)
    printer << ' ';
  printer << llvm::interleaved_array(asyncDependencies);
}

// GPU memory attribution functions shared by LaunchOp and GPUFuncOp.
/// Parses a GPU function memory attribution.
///
/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                        (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  return parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                                  /*allowType=*/true);
}

static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values,
                              ArrayAttr attributes = {}) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      llvm::enumerate(values), p, [&p, attributes](auto pair) {
        BlockArgument v = pair.value();
        p << v << " : " << v.getType();

        size_t attributionIndex = pair.index();
        DictionaryAttr attrs;
        if (attributes && attributionIndex < attributes.size())
          attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
        if (attrs)
          p.printOptionalAttrDict(attrs.getValue());
      });
  p << ')';
}
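
// Example of the attribution syntax these helpers handle (illustrative):
//
//   gpu.func @kernel()
//       workgroup(%buf : memref<32xf32, #gpu.address_space<workgroup>>)
//       private(%tmp : memref<4xf32, #gpu.address_space<private>>) { ... }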

/// Verifies a GPU function memory attribution.
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        gpu::AddressSpace memorySpace) {
  for (Value v : attributions) {
    auto type = llvm::dyn_cast<MemRefType>(v.getType());
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    // We can only verify the address space if it hasn't already been lowered
    // from the AddressSpaceAttr to a target-specific numeric value.
    auto addressSpace =
        llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
    if (!addressSpace)
      continue;
    if (addressSpace.getValue() != memorySpace)
      return op->emitOpError()
             << "expected memory space " << stringifyAddressSpace(memorySpace)
             << " in attribution";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//

static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
                                           Type resType) {
  using Kind = gpu::AllReduceOperation;
  if (llvm::is_contained(
          {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
          opName)) {
    if (!isa<FloatType>(resType))
      return failure();
  }

  if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
                          Kind::AND, Kind::OR, Kind::XOR},
                         opName)) {
    if (!isa<IntegerType>(resType))
      return failure();
  }

  return success();
}

LogicalResult gpu::AllReduceOp::verifyRegions() {
  if (getBody().empty() != getOp().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    if (getBody().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    gpu::AllReduceOperation opName = *getOp();
    if (failed(verifyReduceOpAndType(opName, getType()))) {
      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                         << "` reduction operation is not compatible with type "
                         << getType();
    }
  }

  return success();
}
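
// The two accepted forms, for reference (illustrative): either an op
// attribute or a non-empty reduction region, never both.
//
//   %sum = gpu.all_reduce add %val uniform {} : (f32) -> f32
//
//   %max = gpu.all_reduce %val {
//   ^bb(%lhs : f32, %rhs : f32):
//     %m = arith.maximumf %lhs, %rhs : f32
//     "gpu.yield"(%m) : (f32) -> ()
//   } : (f32) -> f32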

static bool canMakeGroupOpUniform(Operation *op) {
  auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
  if (!launchOp)
    return false;

  Region &body = launchOp.getBody();
  assert(!body.empty() && "Invalid region");

  // Only convert ops in gpu::launch entry block for now.
  return op->getBlock() == &body.front();
}

OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

// TODO: Support optional custom attributes (without dialect prefix).
static ParseResult parseAllReduceOperation(AsmParser &parser,
                                           AllReduceOperationAttr &attr) {
  StringRef enumStr;
  if (!parser.parseOptionalKeyword(&enumStr)) {
    std::optional<AllReduceOperation> op =
        gpu::symbolizeAllReduceOperation(enumStr);
    if (!op)
      return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
    attr = AllReduceOperationAttr::get(parser.getContext(), *op);
  }
  return success();
}

static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
                                    AllReduceOperationAttr attr) {
  if (attr)
    attr.print(printer);
}

//===----------------------------------------------------------------------===//
// SubgroupReduceOp
//===----------------------------------------------------------------------===//

LogicalResult gpu::SubgroupReduceOp::verify() {
  Type elemType = getType();
  if (auto vecTy = dyn_cast<VectorType>(elemType)) {
    if (vecTy.isScalable())
      return emitOpError() << "is not compatible with scalable vector types";

    elemType = vecTy.getElementType();
  }

  gpu::AllReduceOperation opName = getOp();
  if (failed(verifyReduceOpAndType(opName, elemType))) {
    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                       << "` reduction operation is not compatible with type "
                       << getType();
  }

  auto clusterSize = getClusterSize();
  if (clusterSize) {
    uint32_t size = *clusterSize;
    if (!llvm::isPowerOf2_32(size)) {
      return emitOpError() << "cluster size " << size
                           << " is not a power of two";
    }
  }

  uint32_t stride = getClusterStride();
  if (stride != 1 && !clusterSize) {
    return emitOpError() << "cluster stride can only be specified if cluster "
                            "size is specified";
  }
  if (!llvm::isPowerOf2_32(stride)) {
    return emitOpError() << "cluster stride " << stride
                         << " is not a power of two";
  }

  return success();
}
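
// Examples that satisfy the constraints checked above (illustrative):
//
//   %r0 = gpu.subgroup_reduce add %a : (f32) -> f32
//   %r1 = gpu.subgroup_reduce add %a cluster(size = 4) : (f32) -> f32
//   %r2 = gpu.subgroup_reduce add %a cluster(size = 4, stride = 2)
//       : (f32) -> f32
//
// A cluster(size = 3) would be rejected: the size must be a power of two.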

OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (getClusterSize() == 1)
    return getValue();

  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AsyncOpInterface
//===----------------------------------------------------------------------===//

void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(0, {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);

  // Async dependencies is the only variadic operand.
  if (!sizeAttr)
    return;

  SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
  ++sizes.front();
  op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
}

//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value getBlockSizeX, Value getBlockSizeY,
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,
                     Type asyncTokenType, ValueRange asyncDependencies,
                     TypeRange workgroupAttributions,
                     TypeRange privateAttributions, Value clusterSizeX,
                     Value clusterSizeY, Value clusterSizeZ,
                     FlatSymbolRefAttr module, FlatSymbolRefAttr function) {
  OpBuilder::InsertionGuard g(builder);

  // Add a WorkGroup attribution attribute. This attribute is required to
  // identify private attributions in the list of block arguments.
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));

  // Add Op operands.
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                      getBlockSizeY, getBlockSizeZ});
  if (clusterSizeX)
    result.addOperands(clusterSizeX);
  if (clusterSizeY)
    result.addOperands(clusterSizeY);
  if (clusterSizeZ)
    result.addOperands(clusterSizeZ);
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  // Add optional module and function attributes.
  if (module)
    result.addAttribute(getModuleAttrName(result.name), module);
  if (function)
    result.addAttribute(getFunctionAttrName(result.name), function);

  // Create a kernel body region with kNumConfigRegionAttributes + N memory
  // attributions, where the first kNumConfigRegionAttributes arguments have
  // `index` type and the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = builder.createBlock(kernelRegion);
  // TODO: Allow passing in proper locations here.
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(builder.getIndexType(), result.location);
  // Add WorkGroup & Private attributions to the region arguments.
  for (Type argTy : workgroupAttributions)
    body->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    body->addArgument(argTy, result.location);
  // Fill OperandSegmentSize Attribute.
  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  segmentSizes[7] = clusterSizeX ? 1 : 0;
  segmentSizes[8] = clusterSizeY ? 1 : 0;
  segmentSizes[9] = clusterSizeZ ? 1 : 0;
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}

KernelDim3 LaunchOp::getBlockIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

std::optional<KernelDim3> LaunchOp::getClusterIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[12], args[13], args[14]};
}

std::optional<KernelDim3> LaunchOp::getClusterSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[15], args[16], args[17]};
}
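
// Taken together, the accessors above define the entry block argument layout
// of a gpu.launch region: args[0..2] block ids, args[3..5] thread ids,
// args[6..8] grid sizes, args[9..11] block sizes, and, when a cluster is
// present, args[12..14] cluster ids and args[15..17] cluster sizes, followed
// by any memory attributions.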

KernelDim3 LaunchOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  if (!hasClusterSize())
    return std::nullopt;
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchOp::verify() {
  if (!(hasClusterSize()) &&
      (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
    return emitOpError() << "cluster size must be all present";
  return success();
}

LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!getBody().empty()) {
    if (getBody().getNumArguments() <
        kNumConfigRegionAttributes + getNumWorkgroupAttributions())
      return emitOpError("unexpected number of region arguments");
  }

  // Verify Attributions Address Spaces.
  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}

// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}

void LaunchOp::print(OpAsmPrinter &p) {
  if (getAsyncToken()) {
    p << " async";
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }
  // Print the launch configuration.
  if (hasClusterSize()) {
    p << ' ' << getClustersKeyword();
    printSizeAssignment(p, getClusterSize().value(),
                        getClusterSizeOperandValues().value(),
                        getClusterIds().value());
  }
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

  // Print optional module attribute.
  StringRef moduleAttrName = getModuleAttrName();
  if (auto module = getModule()) {
    p << ' ' << moduleAttrName << '(';
    p.printSymbolName(*module);
    p << ')';
  }
  // Print optional function attribute.
  StringRef functionAttrName = getFunctionAttrName();
  if (auto function = getFunction()) {
    p << ' ' << functionAttrName << '(';
    p.printSymbolName(*function);
    p << ')';
  }

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());

  p << ' ';

  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              LaunchOp::getOperandSegmentSizeAttr(),
                              getNumWorkgroupAttributionsAttrName(),
                              moduleAttrName, functionAttrName});
}

// Parse the size assignment blocks for blocks and threads. These have the form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
  if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
                              /*allowResultNumber=*/false) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
        parser.parseEqual() || parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}

/// Parses a Launch operation.
/// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
///       (`clusters` `(` ssa-id-list `)` `in` ssa-reassignment)?
///       `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
///       `threads` `(` ssa-id-list `)` `in` ssa-reassignment
///       (`dynamic_shared_memory_size` ssa-use)?
///       (`module(` symbol-ref-id `)`)?
///       (`function(` symbol-ref-id `)`)?
///       memory-attribution
///       region attr-dict?
/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
      sizes(LaunchOp::kNumConfigOperands);

  // Region arguments to be created.
  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);

  // Parse optional async dependencies.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
  Type asyncTokenType;
  if (failed(
          parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
      parser.resolveOperands(asyncDependencies, asyncTokenType,
                             result.operands))
    return failure();
  if (parser.getNumResults() > 0)
    result.types.push_back(asyncTokenType);

  bool hasCluster = false;
  if (succeeded(
          parser.parseOptionalKeyword(LaunchOp::getClustersKeyword().data()))) {
    hasCluster = true;
    sizes.resize(9);
    regionArgs.resize(18);
  }
  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);

  // The last three size segments assign the cluster size; in the region
  // argument list, these are the last six arguments.
  if (hasCluster) {
    if (parseSizeAssignment(parser, sizesRef.drop_front(6),
                            regionArgsRef.slice(15, 3),
                            regionArgsRef.slice(12, 3)))
      return failure();
  }
  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers. In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
  bool hasDynamicSharedMemorySize = false;
  if (!parser.parseOptionalKeyword(
          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
    hasDynamicSharedMemorySize = true;
    if (parser.parseOperand(dynamicSharedMemorySize) ||
        parser.resolveOperand(dynamicSharedMemorySize,
                              parser.getBuilder().getI32Type(),
                              result.operands))
      return failure();
  }

  // Parse optional module attribute.
  StringRef moduleAttrName = getModuleAttrName(result.name);
  if (succeeded(parser.parseOptionalKeyword(moduleAttrName))) {
    FlatSymbolRefAttr moduleSymbol;
    if (parser.parseLParen() ||
        parser.parseAttribute(moduleSymbol, Type(), moduleAttrName,
                              result.attributes) ||
        parser.parseRParen())
      return failure();
  }
  // Parse optional function attribute.
  StringRef functionAttrName = getFunctionAttrName(result.name);
  if (succeeded(parser.parseOptionalKeyword(functionAttrName))) {
    FlatSymbolRefAttr funcSymbol;
    if (parser.parseLParen() ||
        parser.parseAttribute(funcSymbol, Type(), functionAttrName,
                              result.attributes) ||
        parser.parseRParen())
      return failure();
  }

  // Create the region arguments; the region has kNumConfigRegionAttributes
  // arguments that correspond to block/thread identifiers and grid/block
  // sizes, all having `index` type, a variadic number of WorkGroup
  // Attributions and a variadic number of Private Attributions. The number
  // of WorkGroup Attributions is stored in the attr with name:
  // LaunchOp::getNumWorkgroupAttributionsAttrName().
  Type index = parser.getBuilder().getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes + 6, index);

  SmallVector<OpAsmParser::Argument> regionArguments;
  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
    OpAsmParser::Argument arg;
    arg.ssaName = std::get<0>(ssaValueAndType);
    arg.type = std::get<1>(ssaValueAndType);
    regionArguments.push_back(arg);
  }

  Builder &builder = parser.getBuilder();
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
                               regionArguments)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = regionArguments.size() -
                               LaunchOp::kNumConfigRegionAttributes -
                               (hasCluster ? 6 : 0);
  result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
                               regionArguments)))
    return failure();

  // Introduce the body region and parse it. The region has
  // kNumConfigRegionAttributes arguments that correspond to
  // block/thread identifiers and grid/block sizes, all having `index` type.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArguments) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();

  if (!hasCluster) {
    segmentSizes[7] = 0;
    segmentSizes[8] = 0;
    segmentSizes[9] = 0;
  }
  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
  return success();
}
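
// A complete example of the syntax parsed above (illustrative):
//
//   gpu.launch blocks(%bx, %by, %bz) in (%gx = %c8, %gy = %c1, %gz = %c1)
//              threads(%tx, %ty, %tz) in (%sx = %c32, %sy = %c1, %sz = %c1) {
//     // body uses %bx, %tx, ...
//     gpu.terminator
//   }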

/// Simplify the gpu.launch when the range of a thread or block ID is
/// trivially known to be one.
struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter &rewriter) const override {
    // If the range implies a single value for `id`, replace `id`'s uses by
    // zero.
    Value zero;
    bool simplified = false;
    auto constPropIdUses = [&](Value id, Value size) {
      // Check if size is trivially one.
      if (!matchPattern(size, m_One()))
        return;
      if (id.getUses().empty())
        return;
      if (!simplified) {
        // Create a zero value the first time.
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(&op.getBody().front());
        zero =
            arith::ConstantIndexOp::create(rewriter, op.getLoc(), /*value=*/0);
      }
      rewriter.replaceAllUsesWith(id, zero);
      simplified = true;
    };
    constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
    constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
    constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
    constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
    constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
    constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());

    return success(simplified);
  }
};
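
// Illustrative effect of the pattern: if the z-extent is the constant 1, the
// range [0, 1) forces %bz to 0, so its uses are replaced by a materialized
// zero constant.
//
//   gpu.launch blocks(%bx, %by, %bz) in (%gx = %n, %gy = %n, %gz = %c1) ...
//     // ... uses of %bz now read an 'arith.constant 0 : index' instead.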

void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
                                           MLIRContext *context) {
  rewrites.add<FoldLaunchArguments>(context);
}

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  assert(kernelSymbol.getNestedReferences().size() == 1 &&
         "expected a symbol reference with a single nested reference");
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);

  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernelSymbol;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  llvm::fill(prop.operandSegmentSizes, 1);
  prop.operandSegmentSizes[0] = asyncDependencies.size();
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
  auto kernelSymbol =
      SymbolRefAttr::get(kernelModule.getNameAttr(),
                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});
  build(builder, result, kernelSymbol, gridSize, getBlockSize,
        dynamicSharedMemorySize, kernelOperands, asyncTokenType,
        asyncDependencies, clusterSize);
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernel, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Value asyncObject,
                         std::optional<KernelDim3> clusterSize) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  if (asyncObject)
    result.addOperands(asyncObject);
  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernel;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  llvm::fill(prop.operandSegmentSizes, 1);
  prop.operandSegmentSizes[0] = 0;
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
}

StringAttr LaunchFuncOp::getKernelModuleName() {
  return getKernel().getRootReference();
}

StringAttr LaunchFuncOp::getKernelName() {
  return getKernel().getLeafReference();
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getKernelOperands().size();
}

Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getKernelOperands()[i];
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
  assert(hasClusterSize() &&
         "cluster size is not set, check hasClusterSize() first");
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchFuncOp::verify() {
  auto module = (*this)->getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module->getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  if (hasClusterSize()) {
    if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
        getClusterSizeZ().getType() != getClusterSizeX().getType())
      return emitOpError()
             << "expects types of the cluster dimensions must be the same";
  }

  return success();
}

static ParseResult
parseLaunchDimType(OpAsmParser &parser, Type &dimTy,
                   std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
                   Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
  if (succeeded(parser.parseOptionalColon())) {
    if (parser.parseType(dimTy))
      return failure();
  } else {
    dimTy = IndexType::get(parser.getContext());
  }
  if (clusterValue.has_value()) {
    clusterXTy = clusterYTy = clusterZTy = dimTy;
  }
  return success();
}

static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
                               Value clusterValue, Type clusterXTy,
                               Type clusterYTy, Type clusterZTy) {
  if (!dimTy.isIndex())
    printer << ": " << dimTy;
}

static ParseResult parseLaunchFuncOperands(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
    SmallVectorImpl<Type> &argTypes) {
  if (parser.parseOptionalKeyword("args"))
    return success();

  auto parseElement = [&]() -> ParseResult {
    return failure(parser.parseOperand(argNames.emplace_back()) ||
                   parser.parseColonType(argTypes.emplace_back()));
  };

  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
                                        parseElement, " in argument list");
}

static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
                                    OperandRange operands, TypeRange types) {
  if (operands.empty())
    return;
  printer << "args(";
  llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
                        [&](const auto &pair) {
                          auto [operand, type] = pair;
                          printer << operand << " : " << type;
                        });
  printer << ")";
}

//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//

void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
                      int32_t offset, int32_t width, ShuffleMode mode) {
  build(builder, result, value,
        arith::ConstantOp::create(builder, result.location,
                                  builder.getI32IntegerAttr(offset)),
        arith::ConstantOp::create(builder, result.location,
                                  builder.getI32IntegerAttr(width)),
        mode);
}
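
// Minimal usage sketch of the convenience builder above (builder, location,
// and value names are hypothetical):
//
//   auto shfl = gpu::ShuffleOp::create(b, loc, val, /*offset=*/1,
//                                      /*width=*/32, gpu::ShuffleMode::XOR);
//
// This materializes the offset and width as arith.constant i32 values before
// delegating to the generated builder.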

//===----------------------------------------------------------------------===//
// RotateOp
//===----------------------------------------------------------------------===//

LogicalResult RotateOp::verify() {
  uint32_t offset = getOffset();
  uint32_t width = getWidth();

  if (offset >= width) {
    return emitOpError() << "offset must be in the range [0, " << width << ")";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// BarrierOp
//===----------------------------------------------------------------------===//

namespace {

/// Remove a gpu.barrier that is directly followed by another gpu.barrier;
/// the threads are already synchronized.
LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
                                          PatternRewriter &rewriter) {
  if (isa_and_nonnull<BarrierOp>(op->getNextNode())) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}

} // end anonymous namespace

void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add(eraseRedundantGpuBarrierOps);
}
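
// Illustrative effect of the canonicalization: two back-to-back barriers
// collapse into one, since the second synchronization is redundant.
//
//   gpu.barrier
//   gpu.barrier   // erased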

//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      getFunctionType().getNumInputs() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                      StringRef name, FunctionType type,
                      TypeRange workgroupAttributions,
                      TypeRange privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  OpBuilder::InsertionGuard g(builder);

  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));
  result.addAttributes(attrs);
  Region *body = result.addRegion();
  Block *entryBlock = builder.createBlock(body);

  // TODO: Allow passing in proper locations here.
  for (Type argTy : type.getInputs())
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : workgroupAttributions)
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    entryBlock->addArgument(argTy, result.location);
}
1536 
1537 /// Parses a GPU function memory attribution.
1538 ///
1539 /// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
1540 /// (`private` `(` ssa-id-and-type-list `)`)?
1541 ///
1542 /// Note that this function parses only one of the two similar parts, with the
1543 /// keyword provided as argument.
1544 static ParseResult
1545 parseAttributions(OpAsmParser &parser, StringRef keyword,
1547  Attribute &attributionAttrs) {
1548  // If we could not parse the keyword, just assume empty list and succeed.
1549  if (failed(parser.parseOptionalKeyword(keyword)))
1550  return success();
1551 
1552  size_t existingArgs = args.size();
1553  ParseResult result =
1554  parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
1555  /*allowType=*/true, /*allowAttrs=*/true);
1556  if (failed(result))
1557  return result;
1558 
1559  bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
1560  [](const OpAsmParser::Argument &arg) -> bool {
1561  return arg.attrs && !arg.attrs.empty();
1562  });
1563  if (!hadAttrs) {
1564  attributionAttrs = nullptr;
1565  return result;
1566  }
1567 
1568  Builder &builder = parser.getBuilder();
1569  SmallVector<Attribute> attributionAttrsVec;
1570  for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
1571  if (!argument.attrs)
1572  attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
1573  else
1574  attributionAttrsVec.push_back(argument.attrs);
1575  }
1576  attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
1577  return result;
1578 }
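
// The concrete syntax handled by this helper looks like (illustrative):
//
//   workgroup(%sum : memref<32xf32, #gpu.address_space<workgroup>>)
//   private(%tmp : memref<1xf32, #gpu.address_space<private>>)
//
// Per-argument attribute dictionaries, when present, are collected into
// `attributionAttrs`; otherwise the attribute is left null.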
1579 
1580 /// Parses a GPU function.
1581 ///
1582 /// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
1583 /// (`->` function-result-list)? memory-attribution `kernel`?
1584 /// function-attributes? region
1585 ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
1586  SmallVector<OpAsmParser::Argument> entryArgs;
1587  SmallVector<DictionaryAttr> resultAttrs;
1588  SmallVector<Type> resultTypes;
1589  bool isVariadic;
1590 
1591  // Parse the function name.
1592  StringAttr nameAttr;
1593  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
1594  result.attributes))
1595  return failure();
1596 
1597  auto signatureLocation = parser.getCurrentLocation();
1598  if (failed(function_interface_impl::parseFunctionSignatureWithArguments(
1599  parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
1600  resultAttrs)))
1601  return failure();
1602 
1603  if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
1604  return parser.emitError(signatureLocation)
1605  << "gpu.func requires named arguments";
1606 
1607  // Construct the function type. More types will be added to the region, but
1608  // not to the function type.
1609  Builder &builder = parser.getBuilder();
1610 
1611  SmallVector<Type> argTypes;
1612  for (auto &arg : entryArgs)
1613  argTypes.push_back(arg.type);
1614  auto type = builder.getFunctionType(argTypes, resultTypes);
1615  result.addAttribute(getFunctionTypeAttrName(result.name),
1616  TypeAttr::get(type));
1617 
1618  call_interface_impl::addArgAndResultAttrs(
1619  builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
1620  getResAttrsAttrName(result.name));
1621 
1622  Attribute workgroupAttributionAttrs;
1623  // Parse workgroup memory attributions.
1624  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
1625  entryArgs, workgroupAttributionAttrs)))
1626  return failure();
1627 
1628  // Store the number of operands we just parsed as the number of workgroup
1629  // memory attributions.
1630  unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
1631  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
1632  builder.getI64IntegerAttr(numWorkgroupAttrs));
1633  if (workgroupAttributionAttrs)
1634  result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
1635  workgroupAttributionAttrs);
1636 
1637  Attribute privateAttributionAttrs;
1638  // Parse private memory attributions.
1639  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
1640  entryArgs, privateAttributionAttrs)))
1641  return failure();
1642  if (privateAttributionAttrs)
1643  result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
1644  privateAttributionAttrs);
1645 
1646  // Parse the kernel attribute if present.
1647  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
1648  result.addAttribute(GPUDialect::getKernelFuncAttrName(),
1649  builder.getUnitAttr());
1650 
1651  // Parse attributes.
1652  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1653  return failure();
1654 
1655  // Parse the region. If no argument names were provided, take all names
1656  // (including those of attributions) from the entry block.
1657  auto *body = result.addRegion();
1658  return parser.parseRegion(*body, entryArgs);
1659 }
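
// Putting the grammar above together, an accepted function might look like
// (illustrative):
//
//   gpu.func @kernel_fn(%arg0 : f32, %arg1 : memref<?xf32>)
//       workgroup(%buf : memref<32xf32, #gpu.address_space<workgroup>>)
//       kernel attributes {foo = "bar"} {
//     gpu.return
//   }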
1660 
1661 void GPUFuncOp::print(OpAsmPrinter &p) {
1662  p << ' ';
1663  p.printSymbolName(getName());
1664 
1665  FunctionType type = getFunctionType();
1666  function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
1667  /*isVariadic=*/false,
1668  type.getResults());
1669 
1670  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
1671  getWorkgroupAttribAttrs().value_or(nullptr));
1672  printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
1673  getPrivateAttribAttrs().value_or(nullptr));
1674  if (isKernel())
1675  p << ' ' << getKernelKeyword();
1676 
1677  function_interface_impl::printFunctionAttributes(
1678  p, *this,
1679  {getNumWorkgroupAttributionsAttrName(),
1680  GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1681  getArgAttrsAttrName(), getResAttrsAttrName(),
1682  getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
1683  p << ' ';
1684  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
1685 }
1686 
1687 static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
1688  StringAttr attrName) {
1689  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1690  if (!allAttrs || index >= allAttrs.size())
1691  return DictionaryAttr();
1692  return llvm::cast<DictionaryAttr>(allAttrs[index]);
1693 }
1694 
1695 DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
1696  return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
1697 }
1698 
1699 DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
1700  return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
1701 }
1702 
1703 static void setAttributionAttrs(GPUFuncOp op, unsigned index,
1704  DictionaryAttr value, StringAttr attrName) {
1705  MLIRContext *ctx = op.getContext();
1706  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1707  SmallVector<Attribute> elements;
1708  if (allAttrs)
1709  elements.append(allAttrs.begin(), allAttrs.end());
1710  while (elements.size() <= index)
1711  elements.push_back(DictionaryAttr::get(ctx));
1712  if (!value)
1713  elements[index] = DictionaryAttr::get(ctx);
1714  else
1715  elements[index] = value;
1716  ArrayAttr newValue = ArrayAttr::get(ctx, elements);
1717  op->setAttr(attrName, newValue);
1718 }
1719 
1720 void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
1721  DictionaryAttr value) {
1722  setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
1723 }
1724 
1725 void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
1726  DictionaryAttr value) {
1727  setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
1728 }
1729 
1730 static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
1731  StringAttr name, StringAttr attrsName) {
1732  DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
1733  if (!dict)
1734  return Attribute();
1735  return dict.get(name);
1736 }
1737 
1738 Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
1739  StringAttr name) {
1740  assert(index < getNumWorkgroupAttributions() &&
1741  "index must map to a workgroup attribution");
1742  return getAttributionAttr(*this, index, name,
1743  getWorkgroupAttribAttrsAttrName());
1744 }
1745 
1746 Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
1747  StringAttr name) {
1748  assert(index < getNumPrivateAttributions() &&
1749  "index must map to a private attribution");
1750  return getAttributionAttr(*this, index, name,
1751  getPrivateAttribAttrsAttrName());
1752 }
1753 
1754 static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
1755  Attribute value, StringAttr attrsName) {
1756  MLIRContext *ctx = op.getContext();
1757  SmallVector<NamedAttribute> elems;
1758  DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
1759  if (oldDict)
1760  elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1761 
1762  bool found = false;
1763  bool mustSort = true;
1764  for (unsigned i = 0, e = elems.size(); i < e; ++i) {
1765  if (elems[i].getName() == name) {
1766  found = true;
1767  if (!value) {
1768  std::swap(elems[i], elems[elems.size() - 1]);
1769  elems.pop_back();
1770  } else {
1771  mustSort = false;
1772  elems[i] = NamedAttribute(elems[i].getName(), value);
1773  }
1774  break;
1775  }
1776  }
1777  if (!found) {
1778  if (!value)
1779  return;
1780  elems.emplace_back(name, value);
1781  }
1782  if (mustSort) {
1783  DictionaryAttr::sortInPlace(elems);
1784  }
1785  auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1786  setAttributionAttrs(op, index, newDict, attrsName);
1787 }
1788 
1789 void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
1790  Attribute value) {
1791  assert(index < getNumWorkgroupAttributions() &&
1792  "index must map to a workgroup attribution");
1793  setAttributionAttr(*this, index, name, value,
1794  getWorkgroupAttribAttrsAttrName());
1795 }
1796 
1797 void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
1798  Attribute value) {
1799  assert(index < getNumPrivateAttributions() &&
1800  "index must map to a private attribution");
1801  setAttributionAttr(*this, index, name, value,
1802  getPrivateAttribAttrsAttrName());
1803 }
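
// Hypothetical round-trip through the setters and getters above (`funcOp` and
// `builder` are placeholders): attach a dictionary entry to workgroup
// attribution #0 and read it back.
//
//   funcOp.setWorkgroupAttributionAttr(0, builder.getStringAttr("align"),
//                                      builder.getI64IntegerAttr(16));
//   Attribute align =
//       funcOp.getWorkgroupAttributionAttr(0, builder.getStringAttr("align"));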
1804 
1805 LogicalResult GPUFuncOp::verifyType() {
1806  if (isKernel() && getFunctionType().getNumResults() != 0)
1807  return emitOpError() << "expected void return type for kernel function";
1808 
1809  return success();
1810 }
1811 
1812 /// Verifies the body of the function.
1813 LogicalResult GPUFuncOp::verifyBody() {
1814  if (empty())
1815  return emitOpError() << "expected body with at least one block";
1816  unsigned numFuncArguments = getNumArguments();
1817  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1818  unsigned numBlockArguments = front().getNumArguments();
1819  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1820  return emitOpError() << "expected at least "
1821  << numFuncArguments + numWorkgroupAttributions
1822  << " arguments to body region";
1823 
1824  ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1825  for (unsigned i = 0; i < numFuncArguments; ++i) {
1826  Type blockArgType = front().getArgument(i).getType();
1827  if (funcArgTypes[i] != blockArgType)
1828  return emitOpError() << "expected body region argument #" << i
1829  << " to be of type " << funcArgTypes[i] << ", got "
1830  << blockArgType;
1831  }
1832 
1833  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
1834  GPUDialect::getWorkgroupAddressSpace())) ||
1835  failed(verifyAttributions(getOperation(), getPrivateAttributions(),
1836  GPUDialect::getPrivateAddressSpace())))
1837  return failure();
1838 
1839  return success();
1840 }
1841 
1842 //===----------------------------------------------------------------------===//
1843 // ReturnOp
1844 //===----------------------------------------------------------------------===//
1845 
1846 LogicalResult gpu::ReturnOp::verify() {
1847  GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();
1848 
1849  FunctionType funType = function.getFunctionType();
1850 
1851  if (funType.getNumResults() != getOperands().size())
1852  return emitOpError()
1853  .append("expected ", funType.getNumResults(), " result operands")
1854  .attachNote(function.getLoc())
1855  .append("return type declared here");
1856 
1857  for (const auto &pair : llvm::enumerate(
1858  llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
1859  auto [type, operand] = pair.value();
1860  if (type != operand.getType())
1861  return emitOpError() << "unexpected type `" << operand.getType()
1862  << "' for operand #" << pair.index();
1863  }
1864  return success();
1865 }
1866 
1867 //===----------------------------------------------------------------------===//
1868 // GPUModuleOp
1869 //===----------------------------------------------------------------------===//
1870 
1871 void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1872  StringRef name, ArrayAttr targets,
1873  Attribute offloadingHandler) {
1874  result.addRegion()->emplaceBlock();
1875  Properties &props = result.getOrAddProperties<Properties>();
1876  if (targets)
1877  props.targets = targets;
1878  props.setSymName(builder.getStringAttr(name));
1879  props.offloadingHandler = offloadingHandler;
1880 }
1881 
1882 void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1883  StringRef name, ArrayRef<Attribute> targets,
1884  Attribute offloadingHandler) {
1885  build(builder, result, name,
1886  targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
1887  offloadingHandler);
1888 }
1889 
1890 bool GPUModuleOp::hasTarget(Attribute target) {
1891  if (ArrayAttr targets = getTargetsAttr())
1892  return llvm::count(targets.getValue(), target);
1893  return false;
1894 }
1895 
1896 void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
1897  ArrayAttr &targetsAttr = getProperties().targets;
1898  SmallVector<Attribute> targetsVector(targets);
1899  targetsAttr = ArrayAttr::get(getContext(), targetsVector);
1900 }
1901 
1902 LogicalResult GPUModuleOp::verify() {
1903  auto targets = getOperation()->getAttrOfType<ArrayAttr>("targets");
1904 
1905  if (!targets)
1906  return success();
1907 
1908  for (auto target : targets) {
1909  if (auto verifyTargetAttr =
1910  llvm::dyn_cast<TargetAttrVerifyInterface>(target)) {
1911  if (verifyTargetAttr.verifyTarget(getOperation()).failed())
1912  return failure();
1913  }
1914  }
1915  return success();
1916 }
1917 
1918 //===----------------------------------------------------------------------===//
1919 // GPUBinaryOp
1920 //===----------------------------------------------------------------------===//
1921 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1922  Attribute offloadingHandler, ArrayAttr objects) {
1923  auto &properties = result.getOrAddProperties<Properties>();
1924  result.attributes.push_back(builder.getNamedAttr(
1925  SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1926  properties.objects = objects;
1927  if (offloadingHandler)
1928  properties.offloadingHandler = offloadingHandler;
1929  else
1930  properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
1931 }
1932 
1933 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1934  Attribute offloadingHandler, ArrayRef<Attribute> objects) {
1935  build(builder, result, name, offloadingHandler,
1936  objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
1937 }
1938 
1939 static ParseResult parseOffloadingHandler(OpAsmParser &parser,
1940  Attribute &offloadingHandler) {
1941  if (succeeded(parser.parseOptionalLess())) {
1942  if (parser.parseAttribute(offloadingHandler))
1943  return failure();
1944  if (parser.parseGreater())
1945  return failure();
1946  }
1947  if (!offloadingHandler)
1948  offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
1949  return success();
1950 }
1951 
1952 static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op,
1953  Attribute offloadingHandler) {
1954  if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
1955  printer << '<' << offloadingHandler << '>';
1956 }
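
// In the printed form, the default SelectObjectAttr offloading handler is
// elided (illustrative syntax; object payloads omitted):
//
//   gpu.binary @default_handler [#gpu.object<...>]
//   gpu.binary @custom_handler <#gpu.select_object<1>> [#gpu.object<...>]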
1957 
1958 //===----------------------------------------------------------------------===//
1959 // GPUMemcpyOp
1960 //===----------------------------------------------------------------------===//
1961 
1962 LogicalResult MemcpyOp::verify() {
1963  auto srcType = getSrc().getType();
1964  auto dstType = getDst().getType();
1965 
1966  if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
1967  return emitOpError("arguments have incompatible element type");
1968 
1969  if (failed(verifyCompatibleShape(srcType, dstType)))
1970  return emitOpError("arguments have incompatible shape");
1971 
1972  return success();
1973 }
1974 
1975 namespace {
1976 
1977 /// Erases a common case of copy ops where the destination value is used only
1978 /// by the copy op itself and by alloc and dealloc ops.
1979 struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
1980  using OpRewritePattern<MemcpyOp>::OpRewritePattern;
1981 
1982  LogicalResult matchAndRewrite(MemcpyOp op,
1983  PatternRewriter &rewriter) const override {
1984  Value dest = op.getDst();
1985  Operation *destDefOp = dest.getDefiningOp();
1986  // `dest` must be defined by an op having Allocate memory effect in order to
1987  // perform the folding.
1988  if (!destDefOp ||
1989  !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
1990  return failure();
1991  // We can erase `op` iff `dest` has no other use apart from its
1992  // use by `op` and dealloc ops.
1993  if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
1994  return user != op &&
1995  !hasSingleEffect<MemoryEffects::Free>(user, dest);
1996  }))
1997  return failure();
1998  // We can perform the folding if and only if op has a single async
1999  // dependency and produces an async token as result, or if it does not have
2000  // any async dependency and does not produce any async token result.
2001  if (op.getAsyncDependencies().size() > 1 ||
2002  ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
2003  (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
2004  return failure();
2005  rewriter.replaceOp(op, op.getAsyncDependencies());
2006  return success();
2007  }
2008 };
2009 
2010 } // end anonymous namespace
2011 
2012 void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
2013  MLIRContext *context) {
2014  results.add<EraseTrivialCopyOp>(context);
2015 }
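
// Illustrative input for the pattern above (hypothetical IR): the copy writes
// into an allocation whose only other user is a dealloc, so the copy has no
// observable effect and is erased.
//
//   %dst = gpu.alloc () : memref<16xf32>
//   gpu.memcpy %dst, %src : memref<16xf32>, memref<16xf32>  // erased
//   gpu.dealloc %dst : memref<16xf32>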
2016 
2017 //===----------------------------------------------------------------------===//
2018 // GPU_SubgroupMmaLoadMatrixOp
2019 //===----------------------------------------------------------------------===//
2020 
2021 LogicalResult SubgroupMmaLoadMatrixOp::verify() {
2022  auto srcType = getSrcMemref().getType();
2023  auto resType = getRes().getType();
2024  auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
2025  auto operand = resMatrixType.getOperand();
2026  auto srcMemrefType = llvm::cast<MemRefType>(srcType);
2027 
2028  if (!srcMemrefType.isLastDimUnitStride())
2029  return emitError(
2030  "expected source memref most minor dim to have unit stride");
2031 
2032  if (operand != "AOp" && operand != "BOp" && operand != "COp")
2033  return emitError("only AOp, BOp and COp can be loaded");
2034 
2035  return success();
2036 }
2037 
2038 //===----------------------------------------------------------------------===//
2039 // GPU_SubgroupMmaStoreMatrixOp
2040 //===----------------------------------------------------------------------===//
2041 
2042 LogicalResult SubgroupMmaStoreMatrixOp::verify() {
2043  auto srcType = getSrc().getType();
2044  auto dstType = getDstMemref().getType();
2045  auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
2046  auto dstMemrefType = llvm::cast<MemRefType>(dstType);
2047 
2048  if (!dstMemrefType.isLastDimUnitStride())
2049  return emitError(
2050  "expected destination memref most minor dim to have unit stride");
2051 
2052  if (srcMatrixType.getOperand() != "COp")
2053  return emitError(
2054  "expected the operand matrix being stored to have 'COp' operand type");
2055 
2056  return success();
2057 }
2058 
2059 //===----------------------------------------------------------------------===//
2060 // GPU_SubgroupMmaComputeOp
2061 //===----------------------------------------------------------------------===//
2062 
2063 LogicalResult SubgroupMmaComputeOp::verify() {
2064  enum OperandMap { A, B, C };
2065  SmallVector<MMAMatrixType, 3> opTypes;
2066  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
2067  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
2068  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
2069 
2070  if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
2071  opTypes[C].getOperand() != "COp")
2072  return emitError("operands must be in the order AOp, BOp, COp");
2073 
2074  ArrayRef<int64_t> aShape, bShape, cShape;
2075  aShape = opTypes[A].getShape();
2076  bShape = opTypes[B].getShape();
2077  cShape = opTypes[C].getShape();
2078 
2079  if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
2080  bShape[1] != cShape[1])
2081  return emitError("operand shapes do not satisfy matmul constraints");
2082 
2083  return success();
2084 }
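
// Worked example of the shape constraints above (hypothetical shapes): with
// A : 16x8, B : 8x32, and C : 16x32, aShape[1] == bShape[0] == 8,
// aShape[0] == cShape[0] == 16, and bShape[1] == cShape[1] == 32, so the
// verifier succeeds.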
2085 
2086 LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
2087  SmallVectorImpl<::mlir::OpFoldResult> &results) {
2088  return memref::foldMemRefCast(*this);
2089 }
2090 
2091 LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
2092  SmallVectorImpl<::mlir::OpFoldResult> &results) {
2093  return memref::foldMemRefCast(*this);
2094 }
2095 
2096 //===----------------------------------------------------------------------===//
2097 // GPU_WaitOp
2098 //===----------------------------------------------------------------------===//
2099 
2100 namespace {
2101 
2102 /// Remove, from a gpu.wait op's async dependencies, tokens produced by gpu.wait ops that themselves have no async dependencies.
2103 /// %t = gpu.wait async [] // No async dependencies.
2104 /// ... gpu.wait ... [%t, ...] // %t can be removed.
2105 struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
2106 public:
2107  using OpRewritePattern::OpRewritePattern;
2108 
2109  LogicalResult matchAndRewrite(WaitOp op,
2110  PatternRewriter &rewriter) const final {
2111  auto predicate = [](Value value) {
2112  auto waitOp = value.getDefiningOp<WaitOp>();
2113  return waitOp && waitOp->getNumOperands() == 0;
2114  };
2115  if (llvm::none_of(op.getAsyncDependencies(), predicate))
2116  return failure();
2117  SmallVector<Value> validOperands;
2118  for (Value operand : op->getOperands()) {
2119  if (predicate(operand))
2120  continue;
2121  validOperands.push_back(operand);
2122  }
2123  rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2124  return success();
2125  }
2126 };
2127 
2128 /// Simplify trivial gpu.wait ops for the following patterns.
2129 /// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
2130 /// dependencies).
2131 /// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with
2132 /// %t0.
2133 /// 3. gpu.wait [] ops, i.e. gpu.wait ops that neither have any async
2134 /// dependencies nor return any token.
2135 struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2136 public:
2137  using OpRewritePattern::OpRewritePattern;
2138 
2139  LogicalResult matchAndRewrite(WaitOp op,
2140  PatternRewriter &rewriter) const final {
2141  // Erase gpu.wait ops that neither have any async dependencies nor return
2142  // any async token.
2143  if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2144  rewriter.eraseOp(op);
2145  return success();
2146  }
2147  // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2148  if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2149  op.getAsyncToken()) {
2150  rewriter.replaceOp(op, op.getAsyncDependencies());
2151  return success();
2152  }
2153  // Erase %t = gpu.wait async ... ops, where %t has no uses.
2154  if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2155  rewriter.eraseOp(op);
2156  return success();
2157  }
2158  return failure();
2159  }
2160 };
2161 
2162 } // end anonymous namespace
2163 
2164 void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2165  MLIRContext *context) {
2166  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2167 }
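
// Illustrative effect of the two patterns above (hypothetical IR):
//
//   %t0 = gpu.wait async        // no async dependencies
//   %t1 = gpu.wait async [%t0]  // %t0 is dropped from the operand list
//   gpu.wait [%t1]              // then %t1 forwards to its single dependency
//
// Repeated application folds the entire chain away.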
2168 
2169 //===----------------------------------------------------------------------===//
2170 // GPU_AllocOp
2171 //===----------------------------------------------------------------------===//
2172 
2173 LogicalResult AllocOp::verify() {
2174  auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2175 
2176  if (getDynamicSizes().size() != memRefType.getNumDynamicDims())
2177  return emitOpError("dimension operand count does not equal memref "
2178  "dynamic dimension count");
2179 
2180  unsigned numSymbols = 0;
2181  if (!memRefType.getLayout().isIdentity())
2182  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2183  if (getSymbolOperands().size() != numSymbols) {
2184  return emitOpError(
2185  "symbol operand count does not equal memref symbol count");
2186  }
2187 
2188  return success();
2189 }
2190 
2191 namespace {
2192 
2193 /// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
2194 /// `memref::AllocOp`.
2195 struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2196  using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2197 
2198  LogicalResult matchAndRewrite(memref::DimOp dimOp,
2199  PatternRewriter &rewriter) const override {
2200  std::optional<int64_t> index = dimOp.getConstantIndex();
2201  if (!index)
2202  return failure();
2203 
2204  auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2205  if (!memrefType || index.value() >= memrefType.getRank() ||
2206  !memrefType.isDynamicDim(index.value()))
2207  return failure();
2208 
2209  auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2210  if (!alloc)
2211  return failure();
2212 
2213  Value substituteOp = *(alloc.getDynamicSizes().begin() +
2214  memrefType.getDynamicDimIndex(index.value()));
2215  rewriter.replaceOp(dimOp, substituteOp);
2216  return success();
2217  }
2218 };
2219 
2220 } // namespace
2221 
2222 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2223  MLIRContext *context) {
2224  results.add<SimplifyDimOfAllocOp>(context);
2225 }
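
// Illustrative folding performed by the pattern above (hypothetical IR):
//
//   %c0 = arith.constant 0 : index
//   %m = gpu.alloc (%size) : memref<?xf32>
//   %d = memref.dim %m, %c0 : memref<?xf32>   // replaced by %size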
2226 
2227 //===----------------------------------------------------------------------===//
2228 // GPU object attribute
2229 //===----------------------------------------------------------------------===//
2230 
2231 LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2232  Attribute target, CompilationTarget format,
2233  StringAttr object, DictionaryAttr properties,
2234  KernelTableAttr kernels) {
2235  if (!target)
2236  return emitError() << "the target attribute cannot be null";
2237  if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2238  return success();
2239  return emitError() << "the target attribute must implement or promise the "
2240  "`gpu::TargetAttrInterface`";
2241 }
2242 
2243 namespace {
2244 ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2245  StringAttr &object) {
2246  std::optional<CompilationTarget> formatResult;
2247  StringRef enumKeyword;
2248  auto loc = odsParser.getCurrentLocation();
2249  if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
2250  formatResult = CompilationTarget::Fatbin;
2251  if (!formatResult &&
2252  (formatResult =
2253  gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2254  odsParser.parseEqual())
2255  return odsParser.emitError(loc, "expected an equal sign");
2256  if (!formatResult)
2257  return odsParser.emitError(loc, "expected keyword for GPU object format");
2258  FailureOr<StringAttr> objectResult =
2259  FieldParser<StringAttr>::parse(odsParser);
2260  if (failed(objectResult))
2261  return odsParser.emitError(odsParser.getCurrentLocation(),
2262  "failed to parse GPU_ObjectAttr parameter "
2263  "'object' which is to be a `StringAttr`");
2264  format = *formatResult;
2265  object = *objectResult;
2266  return success();
2267 }
2268 
2269 void printObject(AsmPrinter &odsParser, CompilationTarget format,
2270  StringAttr object) {
2271  if (format != CompilationTarget::Fatbin)
2272  odsParser << stringifyEnum(format) << " = ";
2273  odsParser << object;
2274 }
2275 } // namespace
2276 
2277 //===----------------------------------------------------------------------===//
2278 // GPU select object attribute
2279 //===----------------------------------------------------------------------===//
2280 
2281 LogicalResult
2282 gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2283  Attribute target) {
2284  // Check `target`; it can be null, an integer attr, or a GPU Target attribute.
2285  if (target) {
2286  if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2287  if (intAttr.getInt() < 0) {
2288  return emitError() << "the object index must be non-negative";
2289  }
2290  } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2291  return emitError()
2292  << "the target attribute must be a GPU Target attribute";
2293  }
2294  }
2295  return success();
2296 }
2297 
2298 //===----------------------------------------------------------------------===//
2299 // DynamicSharedMemoryOp
2300 //===----------------------------------------------------------------------===//
2301 
2302 LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2303  if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2304  return emitOpError() << "must be inside an op with symbol table";
2305 
2306  MemRefType memrefType = getResultMemref().getType();
2307  // Check address space
2308  if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2309  return emitOpError() << "address space must be "
2310  << gpu::AddressSpaceAttr::getMnemonic() << "<"
2311  << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2312  }
2313  if (memrefType.hasStaticShape()) {
2314  return emitOpError() << "result memref type must be memref<?xi8, "
2315  "#gpu.address_space<workgroup>>";
2316  }
2317  return success();
2318 }
2319 
2320 //===----------------------------------------------------------------------===//
2321 // GPU WarpExecuteOnLane0Op
2322 //===----------------------------------------------------------------------===//
2323 
2324 void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
2325  p << "(" << getLaneid() << ")";
2326 
2327  SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
2328  auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
2329  p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";
2330 
2331  if (!getArgs().empty())
2332  p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
2333  if (!getResults().empty())
2334  p << " -> (" << getResults().getTypes() << ')';
2335  p << " ";
2336  p.printRegion(getRegion(),
2337  /*printEntryBlockArgs=*/true,
2338  /*printBlockTerminators=*/!getResults().empty());
2339  p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
2340 }
2341 
2342 ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
2343  OperationState &result) {
2344  // Create the region.
2345  result.regions.reserve(1);
2346  Region *warpRegion = result.addRegion();
2347 
2348  auto &builder = parser.getBuilder();
2349  OpAsmParser::UnresolvedOperand laneId;
2350 
2351  // Parse predicate operand.
2352  if (parser.parseLParen() ||
2353  parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
2354  parser.parseRParen())
2355  return failure();
2356 
2357  int64_t warpSize;
2358  if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
2359  parser.parseRSquare())
2360  return failure();
2361  result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
2362  builder.getContext())),
2363  builder.getI64IntegerAttr(warpSize));
2364 
2365  if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
2366  return failure();
2367 
2368  llvm::SMLoc inputsOperandsLoc;
2369  SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
2370  SmallVector<Type> inputTypes;
2371  if (succeeded(parser.parseOptionalKeyword("args"))) {
2372  if (parser.parseLParen())
2373  return failure();
2374 
2375  inputsOperandsLoc = parser.getCurrentLocation();
2376  if (parser.parseOperandList(inputsOperands) ||
2377  parser.parseColonTypeList(inputTypes) || parser.parseRParen())
2378  return failure();
2379  }
2380  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
2381  result.operands))
2382  return failure();
2383 
2384  // Parse optional results type list.
2385  if (parser.parseOptionalArrowTypeList(result.types))
2386  return failure();
2387  // Parse the region.
2388  if (parser.parseRegion(*warpRegion, /*arguments=*/{},
2389  /*argTypes=*/{}))
2390  return failure();
2391  WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
2392 
2393  // Parse the optional attribute list.
2394  if (parser.parseOptionalAttrDict(result.attributes))
2395  return failure();
2396  return success();
2397 }
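
// A form accepted by the parser above (illustrative): with warp size 32, a
// distributed vector<4xf32> operand expands to a vector<128xf32> block
// argument, and the vector<128xf32> yield distributes back to a
// vector<4xf32> result.
//
//   %0 = gpu.warp_execute_on_lane_0(%laneid)[32]
//        args(%v : vector<4xf32>) -> (vector<4xf32>) {
//   ^bb0(%arg0 : vector<128xf32>):
//     ...
//     gpu.yield %r : vector<128xf32>
//   }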
2398 
2399 void WarpExecuteOnLane0Op::getSuccessorRegions(
2400  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2401  if (!point.isParent()) {
2402  regions.push_back(RegionSuccessor(getResults()));
2403  return;
2404  }
2405 
2406  // The warp region is always executed.
2407  regions.push_back(RegionSuccessor(&getWarpRegion()));
2408 }
2409 
2410 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2411  TypeRange resultTypes, Value laneId,
2412  int64_t warpSize) {
2413  build(builder, result, resultTypes, laneId, warpSize,
2414  /*operands=*/{}, /*argTypes=*/{});
2415 }
2416 
2417 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2418  TypeRange resultTypes, Value laneId,
2419  int64_t warpSize, ValueRange args,
2420  TypeRange blockArgTypes) {
2421  result.addOperands(laneId);
2422  result.addAttribute(getAttributeNames()[0],
2423  builder.getI64IntegerAttr(warpSize));
2424  result.addTypes(resultTypes);
2425  result.addOperands(args);
2426  assert(args.size() == blockArgTypes.size());
2427  OpBuilder::InsertionGuard guard(builder);
2428  Region *warpRegion = result.addRegion();
2429  Block *block = builder.createBlock(warpRegion);
2430  for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
2431  block->addArgument(type, arg.getLoc());
2432 }
2433 
2434 /// Helper to check that the distributed vector type is consistent with the
2435 /// expanded type and the warp size.
2436 static LogicalResult verifyDistributedType(Type expanded, Type distributed,
2437  int64_t warpSize, Operation *op) {
2438  // If the types match, there is no distribution.
2439  if (expanded == distributed)
2440  return success();
2441  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
2442  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
2443  if (!expandedVecType || !distributedVecType)
2444  return op->emitOpError("expected vector type for distributed operands.");
2445  if (expandedVecType.getRank() != distributedVecType.getRank() ||
2446  expandedVecType.getElementType() != distributedVecType.getElementType())
2447  return op->emitOpError(
2448  "expected distributed vectors to have same rank and element type.");
2449 
2450  SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
2451  for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
2452  int64_t eDim = expandedVecType.getDimSize(i);
2453  int64_t dDim = distributedVecType.getDimSize(i);
2454  if (eDim == dDim)
2455  continue;
2456  if (eDim % dDim != 0)
2457  return op->emitOpError()
2458  << "expected expanded vector dimension #" << i << " (" << eDim
2459  << ") to be a multiple of the distributed vector dimension ("
2460  << dDim << ")";
2461  scales[i] = eDim / dDim;
2462  }
2463  if (llvm::product_of(scales) != warpSize)
2464  return op->emitOpError()
2465  << "incompatible distribution dimensions from " << expandedVecType
2466  << " to " << distributedVecType << " with warp size = " << warpSize;
2467 
2468  return success();
2469 }
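
// Worked example (hypothetical types): expanded vector<64x4xf32> against
// distributed vector<2x4xf32> gives scales [32, 1], whose product matches a
// warp size of 32, so the distribution verifies.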
2470 
2471 LogicalResult WarpExecuteOnLane0Op::verify() {
2472  if (getArgs().size() != getWarpRegion().getNumArguments())
2473  return emitOpError(
2474  "expected same number of op arguments and block arguments.");
2475  gpu::YieldOp yield = getTerminator();
2476  if (yield.getNumOperands() != getNumResults())
2477  return emitOpError(
2478  "expected same number of yield operands and return values.");
2479  int64_t warpSize = getWarpSize();
2480  for (auto [regionArg, arg] :
2481  llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
2482  if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
2483  warpSize, getOperation())))
2484  return failure();
2485  }
2486  for (auto [yieldOperand, result] :
2487  llvm::zip_equal(yield.getOperands(), getResults())) {
2488  if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
2489  warpSize, getOperation())))
2490  return failure();
2491  }
2492  return success();
2493 }
2494 bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
2495  return succeeded(
2496  verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
2497 }
2498 
2499 gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
2500  return cast<gpu::YieldOp>(getBody()->getTerminator());
2501 }
2502 
2503 //===----------------------------------------------------------------------===//
2504 // GPU_SubgroupBroadcastOp
2505 //===----------------------------------------------------------------------===//
2506 
2507 void gpu::SubgroupBroadcastOp::inferResultRanges(
2508  ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
2509  setResultRange(getResult(), argRanges.front());
2510 }
2511 
2512 Speculation::Speculatability gpu::SubgroupBroadcastOp::getSpeculatability() {
2513  switch (getBroadcastType()) {
2514  case BroadcastType::first_active_lane:
2515  // Cannot speculate a first_active_lane broadcast, because speculating it
2516  // across control flow can change the set of active lanes.
2517  return Speculation::NotSpeculatable;
2518  case BroadcastType::specific_lane:
2519  // Speculation should be safe as long as we are inside structured control flow.
2520  return Speculation::Speculatable;
2521  }
2522 }
2523 
2524 LogicalResult gpu::SubgroupBroadcastOp::verify() {
2525  switch (getBroadcastType()) {
2526  case BroadcastType::first_active_lane:
2527  if (getLane())
2528  return emitOpError()
2529  << "lane can only be specified for `specific_lane` broadcast";
2530  return success();
2531  case BroadcastType::specific_lane:
2532  if (!getLane())
2533  return emitOpError()
2534  << "lane must be specified for `specific_lane` broadcast";
2535  return success();
2536  }
2537 }
2538 
2539 //===----------------------------------------------------------------------===//
2540 // GPU KernelMetadataAttr
2541 //===----------------------------------------------------------------------===//
2542 
2543 KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
2544  DictionaryAttr metadata) {
2545  assert(kernel && "invalid kernel");
2546  return get(kernel.getNameAttr(), kernel.getFunctionType(),
2547  kernel.getAllArgAttrs(), metadata);
2548 }
2549 
2550 KernelMetadataAttr
2551 KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
2552  FunctionOpInterface kernel,
2553  DictionaryAttr metadata) {
2554  assert(kernel && "invalid kernel");
2555  return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
2556  kernel.getAllArgAttrs(), metadata);
2557 }
2558 
2559 KernelMetadataAttr
2560 KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
2561  if (attrs.empty())
2562  return *this;
2563  NamedAttrList attrList;
2564  if (DictionaryAttr dict = getMetadata())
2565  attrList.append(dict);
2566  attrList.append(attrs);
2567  return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
2568  attrList.getDictionary(getContext()));
2569 }
2570 
2571 LogicalResult
2572 KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2573  StringAttr name, Type functionType,
2574  ArrayAttr argAttrs, DictionaryAttr metadata) {
2575  if (name.empty())
2576  return emitError() << "the kernel name can't be empty";
2577  if (argAttrs) {
2578  if (llvm::any_of(argAttrs, [](Attribute attr) {
2579  return !llvm::isa<DictionaryAttr>(attr);
2580  }))
2581  return emitError()
2582  << "all attributes in the array must be a dictionary attribute";
2583  }
2584  return success();
2585 }
2586 
2587 //===----------------------------------------------------------------------===//
2588 // GPU KernelTableAttr
2589 //===----------------------------------------------------------------------===//
2590 
2591 KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2592  ArrayRef<KernelMetadataAttr> kernels,
2593  bool isSorted) {
2594  // Note that `llvm::is_sorted` is invoked at most once, even with assertions enabled.
2595  assert((!isSorted || llvm::is_sorted(kernels)) &&
2596  "expected a sorted kernel array");
2597  // Immediately return the attribute if the array is sorted.
2598  if (isSorted || llvm::is_sorted(kernels))
2599  return Base::get(context, kernels);
2600  // Sort the array.
2601  SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2602  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2603  return Base::get(context, kernelsTmp);
2604 }
2605 
2606 KernelTableAttr KernelTableAttr::getChecked(
2607  function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
2608  ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
2609  // Note that `llvm::is_sorted` is invoked at most once, even with assertions enabled.
2610  assert((!isSorted || llvm::is_sorted(kernels)) &&
2611  "expected a sorted kernel array");
2612  // Immediately return the attribute if the array is sorted.
2613  if (isSorted || llvm::is_sorted(kernels))
2614  return Base::getChecked(emitError, context, kernels);
2615  // Sort the array.
2616  SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2617  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2618  return Base::getChecked(emitError, context, kernelsTmp);
2619 }
2620 
2621 LogicalResult
2622 KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2623  ArrayRef<KernelMetadataAttr> kernels) {
2624  if (kernels.size() < 2)
2625  return success();
2626  // Check that the kernels are uniquely named.
2627  if (std::adjacent_find(kernels.begin(), kernels.end(),
2628  [](KernelMetadataAttr l, KernelMetadataAttr r) {
2629  return l.getName() == r.getName();
2630  }) != kernels.end()) {
2631  return emitError() << "expected all kernels to be uniquely named";
2632  }
2633  return success();
2634 }
2635 
2636 KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
2637  auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2638  return found ? *iterator : KernelMetadataAttr();
2639 }
2640 
2641 KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
2642  auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2643  return found ? *iterator : KernelMetadataAttr();
2644 }
2645 
2646 //===----------------------------------------------------------------------===//
2647 // GPU target options
2648 //===----------------------------------------------------------------------===//
2649 
2650 TargetOptions::TargetOptions(
2651  StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
2652  StringRef cmdOptions, StringRef elfSection,
2653  CompilationTarget compilationTarget,
2654  function_ref<SymbolTable *()> getSymbolTableCallback,
2655  function_ref<void(llvm::Module &)> initialLlvmIRCallback,
2656  function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2657  function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2658  function_ref<void(StringRef)> isaCallback)
2659  : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
2660  cmdOptions, elfSection, compilationTarget,
2661  getSymbolTableCallback, initialLlvmIRCallback,
2662  linkedLlvmIRCallback, optimizedLlvmIRCallback,
2663  isaCallback) {}
2664 
2665 TargetOptions::TargetOptions(
2666  TypeID typeID, StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
2667  StringRef cmdOptions, StringRef elfSection,
2668  CompilationTarget compilationTarget,
2669  function_ref<SymbolTable *()> getSymbolTableCallback,
2670  function_ref<void(llvm::Module &)> initialLlvmIRCallback,
2671  function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2672  function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2673  function_ref<void(StringRef)> isaCallback)
2674  : toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
2675  cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
2676  compilationTarget(compilationTarget),
2677  getSymbolTableCallback(getSymbolTableCallback),
2678  initialLlvmIRCallback(initialLlvmIRCallback),
2679  linkedLlvmIRCallback(linkedLlvmIRCallback),
2680  optimizedLlvmIRCallback(optimizedLlvmIRCallback),
2681  isaCallback(isaCallback), typeID(typeID) {}
2682 
2683 TypeID TargetOptions::getTypeID() const { return typeID; }
2684 
2685 StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2686 
2687 ArrayRef<Attribute> TargetOptions::getLibrariesToLink() const {
2688  return librariesToLink;
2689 }
2690 
2691 StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2692 
2693 StringRef TargetOptions::getELFSection() const { return elfSection; }
2694 
2695 SymbolTable *TargetOptions::getSymbolTable() const {
2696  return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
2697 }
2698 
2699 function_ref<void(llvm::Module &)>
2700 TargetOptions::getInitialLlvmIRCallback() const {
2701  return initialLlvmIRCallback;
2702 }
2703 
2704 function_ref<void(llvm::Module &)>
2705 TargetOptions::getLinkedLlvmIRCallback() const {
2706  return linkedLlvmIRCallback;
2707 }
2708 
2709 function_ref<void(llvm::Module &)>
2710 TargetOptions::getOptimizedLlvmIRCallback() const {
2711  return optimizedLlvmIRCallback;
2712 }
2713 
2714 function_ref<void(StringRef)> TargetOptions::getISACallback() const {
2715  return isaCallback;
2716 }
2717 
2718 CompilationTarget TargetOptions::getCompilationTarget() const {
2719  return compilationTarget;
2720 }
2721 
2722 CompilationTarget TargetOptions::getDefaultCompilationTarget() {
2723  return CompilationTarget::Fatbin;
2724 }
2725 
2726 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2727 TargetOptions::tokenizeCmdOptions(const std::string &cmdOptions) {
2728  std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2729  llvm::StringSaver stringSaver(options.first);
2730  StringRef opts = cmdOptions;
2731  // For a correct tokenization of the command line options, `opts` must be
2732  // unquoted; otherwise the tokenization function returns the whole quoted
2733  // string as a single token, which is not the desired behavior.
2734  // Remove any quotes if they are at the beginning and end of the string:
2735  if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2736  opts.consume_front("\""), opts.consume_back("\"");
2737  if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
2738  opts.consume_front("'"), opts.consume_back("'");
2739 #ifdef _WIN32
2740  llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
2741  /*MarkEOLs=*/false);
2742 #else
2743  llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
2744  /*MarkEOLs=*/false);
2745 #endif // _WIN32
2746  return options;
2747 }
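
// Example behavior (hypothetical input): "\"-O3 -v\"" is unquoted first, so
// tokenization yields the two tokens "-O3" and "-v" rather than the single
// token "-O3 -v".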
2748 
2749 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2750 TargetOptions::tokenizeCmdOptions() const {
2751  return tokenizeCmdOptions(cmdOptions);
2752 }
2753 
2754 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2755 TargetOptions::tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith) {
2756  size_t startPos = cmdOptions.find(startsWith);
2757  if (startPos == std::string::npos)
2758  return {llvm::BumpPtrAllocator(), SmallVector<const char *>()};
2759 
2760  auto tokenized =
2761  tokenizeCmdOptions(cmdOptions.substr(startPos + startsWith.size()));
2762  cmdOptions.resize(startPos);
2763  return tokenized;
2764 }
2765 
2766 MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::gpu::TargetOptions)
2767 
2768 #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2769 #include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2770 
2771 #define GET_ATTRDEF_CLASSES
2772 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2773 
2774 #define GET_OP_CLASSES
2775 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2776 
2777 #include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)
static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)
Parses an optional list of async operands with an optional leading keyword.
Definition: GPUDialect.cpp:491
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)
Definition: GPUDialect.cpp:663
static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName)
static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values, ArrayAttr attributes={})
Definition: GPUDialect.cpp:538
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)
Prints optional async dependencies with its leading keyword.
Definition: GPUDialect.cpp:507
static ParseResult parseOffloadingHandler(OpAsmParser &parser, Attribute &offloadingHandler)
static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices)
Definition: GPUDialect.cpp:997
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName)
static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy, Value clusterValue, Type clusterXTy, Type clusterYTy, Type clusterZTy)
static bool canMakeGroupOpUniform(Operation *op)
Definition: GPUDialect.cpp:641
static std::string getSparseHandleKeyword(SparseHandleKind kind)
Definition: GPUDialect.cpp:291
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op, NamedAttribute attr)
Definition: GPUDialect.cpp:381
static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)
Definition: GPUDialect.cpp:676
static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)
Parses a GPU function memory attribution.
Definition: GPUDialect.cpp:528
static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy, std::optional< OpAsmParser::UnresolvedOperand > clusterValue, Type &clusterXTy, Type &clusterYTy, Type &clusterZTy)
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, Attribute value, StringAttr attrsName)
static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)
static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, Attribute offloadingHandler)
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName, Type resType)
Definition: GPUDialect.cpp:587
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)
Definition: GPUDialect.cpp:932
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, StringAttr attrsName)
static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, gpu::AddressSpace memorySpace)
Verifies a GPU function memory attribution.
Definition: GPUDialect.cpp:561
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
union mlir::linalg::@1247::ArityGroupAndKind::Kind kind
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
#define MINUI(lhs, rhs)
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:323
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:153
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
UnitAttr getUnitAttr()
Definition: Builders.cpp:98
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:200
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:163
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:76
IntegerType getI32Type()
Definition: Builders.cpp:63
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:112
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:91
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:262
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:266
IndexType getIndexType()
Definition: Builders.cpp:51
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition: Builders.cpp:104
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition: Builders.cpp:94
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:98
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it, emitting errors, etc.
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to the list on success.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter, and an optional required operand count.
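In an op's custom assembly format these hooks typically appear together in a static parse method. A minimal sketch for a hypothetical op of the form `my.op %a, %b : type` (the op and its syntax are illustrative):

static ParseResult parseMyOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> operands;
  Type type;
  // Parse `%a, %b {...} : type`, then turn the names into SSA values.
  if (parser.parseOperandList(operands) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type) ||
      parser.resolveOperands(operands, type, result.operands))
    return failure();
  result.addTypes(type);
  return success();
}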
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom print() method on an operation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
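And the matching print side; a minimal sketch for the same hypothetical op (ops that carry a symbol name would use printSymbolName instead of printing the operands):

static void printMyOp(OpAsmPrinter &p, Operation *op) {
  p << ' ';
  p.printOperands(op->getOperands());
  // Print remaining attributes as a `{...}` dictionary, if any are present.
  p.printOptionalAttrDict(op->getAttrs());
  p << " : " << op->getResult(0).getType();
}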
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add a new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
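A common idiom pairs these with the RAII guard above so the insertion point is restored when the helper returns; a minimal sketch:

static void buildIntoNewBlock(OpBuilder &builder, Region &region,
                              Location loc) {
  // The guard restores the original insertion point on scope exit.
  OpBuilder::InsertionGuard guard(builder);
  Block *block = builder.createBlock(&region, region.end(),
                                     {builder.getIndexType()}, {loc});
  // createBlock already moved the insertion point into the new block; this
  // call just makes the "insert at the start" variant explicit.
  builder.setInsertionPointToStart(block);
  // ... create ops here ...
}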
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
Definition: Operation.cpp:256
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:550
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:582
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
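In verifier-style code these methods commonly combine like this (the attribute name my.alignment is purely illustrative):

static LogicalResult checkAlignment(Operation *op) {
  auto align = op->getAttrOfType<IntegerAttr>("my.alignment");
  if (!align)
    return op->emitOpError("expected an integer 'my.alignment' attribute");
  if (align.getInt() <= 0)
    return op->emitError("alignment must be positive");
  return success();
}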
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:793
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
Block & emplaceBlock()
Definition: Region.h:46
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:855
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:646
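These rewriter hooks come together in patterns such as the FoldLaunchArguments simplification documented near the end of this page. A schematic sketch of defining and registering one (MyOp and the pattern are hypothetical; real patterns would usually prefer rewriter.replaceOp, which does both steps at once):

struct ForwardOperand : OpRewritePattern<MyOp> {  // MyOp is hypothetical.
  using OpRewritePattern<MyOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(MyOp op,
                                PatternRewriter &rewriter) const override {
    // Redirect all uses of the result to the single operand, then erase the
    // now-dead op.
    rewriter.replaceAllUsesWith(op.getResult(), op.getOperand());
    rewriter.eraseOp(op);
    return success();
  }
};

static void populateMyPatterns(RewritePatternSet &patterns) {
  patterns.add<ForwardOperand>(patterns.getContext());
}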
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:76
bool isIndex() const
Definition: Types.cpp:54
bool isF32() const
Definition: Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:88
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
bool isF16() const
Definition: Types.cpp:38
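These predicates are how element-type restrictions such as MMAMatrixType::isValidElementType (documented below) are phrased; an illustrative check in the same style:

// Illustrative only; the real predicate is at GPUDialect.cpp:210.
static bool isSmallFloatOrSignlessInt(Type type) {
  return type.isF16() || type.isF32() ||
         (type.isInteger() && !type.isSignedInteger() &&
          !type.isUnsignedInteger());
}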
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
A utility result that is used to signal how to proceed with an ongoing walk (interrupt, advance, or skip).
Definition: WalkResult.h:29
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
static ConcreteT get(MLIRContext *ctx, Args &&...args)
Get or create a new ConcreteT instance within the ctx.
ImplType * getImpl() const
Utility for easy access to the storage instance.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition: GPUDialect.h:131
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Definition: GPUDialect.cpp:202
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction invariants.
Definition: GPUDialect.cpp:187
Type getElementType() const
Get elementType of a single element.
Definition: GPUDialect.cpp:206
static bool isValidElementType(Type elementType)
Check if a type is a valid MMAMatrixType elementType.
Definition: GPUDialect.cpp:210
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
Definition: GPUDialect.cpp:217
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B; this returns which operand in that equation ("AOp", "BOp", or "COp") is held by this type.
Definition: GPUDialect.cpp:208
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction invariants.
Definition: GPUDialect.cpp:193
unsigned getNumDims() const
Get number of dims.
Definition: GPUDialect.cpp:200
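Putting the accessors together, a minimal sketch that builds and inspects a fragment type (the 16x16 "AOp" shape is illustrative):

static MMAMatrixType makeFragment(MLIRContext *ctx) {
  Builder b(ctx);
  // A 16x16 f16 fragment holding the A operand of C += A*B.
  int64_t shape[] = {16, 16};
  MMAMatrixType fragTy = MMAMatrixType::getChecked(
      getDefaultDiagnosticEmitFn(ctx), shape, b.getF16Type(), "AOp");
  assert(fragTy && fragTy.getNumDims() == 2 && fragTy.getOperand() == "AOp");
  return fragTy;
}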
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
std::string cmdOptions
An optional set of command line options to be used by the compilation process.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeCmdOptions() const
Returns a tokenization of the command line options.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith)
Returns a tokenization of the substring of the command line options that starts with startsWith and runs to the end of the options string, removing that substring from cmdOptions.
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments, to the list of operation attributes in result.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the construction invariants of a storage object.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function attributes, prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
void addAsyncDependency(Operation *op, Value token)
Definition: GPUDialect.cpp:740
llvm::StringMap< llvm::SmallString< 8 > > dictionary
A dictionary stores a mapping of template variable names to their assigned string values.
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:478
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type into an MLIR context if it is valid.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:423
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
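Schematically, the simplification checks whether a launch bound is the constant 1 and, if so, replaces the corresponding ID with a constant zero; a condensed fragment of the idea (simplified relative to the real pattern in this file):

static bool simplifyDimIfOne(Value dimSize, Value id, Location loc,
                             PatternRewriter &rewriter) {
  // A dimension of constant size 1 forces the corresponding ID to be 0.
  if (!matchPattern(dimSize, m_One()))
    return false;
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  rewriter.replaceAllUsesWith(id, zero);
  return true;
}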
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:314
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties object of the provided type to be set on the operation at creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Utility class for the GPU dialect to represent triples of Values accessible through .x, .y, and .z.
Definition: GPUDialect.h:39