//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/GPUDialect.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/StringSaver.h"
#include <cassert>
#include <numeric>

using namespace mlir;
using namespace mlir::gpu;

#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// GPU Device Mapping Attributes
//===----------------------------------------------------------------------===//

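// A mapping id at or past MappingId::LinearDim0 denotes a linear mapping;
// getRelativeIndex() rebases such ids so that LinearDim0 corresponds to
// relative index 0.
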
int64_t GPUBlockMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getBlock());
}

bool GPUBlockMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUBlockMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpgroupMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarpgroup());
}

bool GPUWarpgroupMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarp());
}

bool GPUWarpMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUThreadMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getThread());
}

bool GPUThreadMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUThreadMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPULaneMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getLane());
}

bool GPULaneMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPULaneMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds() const { return 64; }

///                 8       4       0
/// Example mask  : 0 0 0 1 1 0 1 0 0
///
/// Active physical (resp. logical) is 2 (0), 4 (1) and 5 (2).
/// Logical id for e.g. 5 (2) constructs filter (1 << 5 - 1).
///
/// Example mask  : 0 0 0 1 1 0 1 0 0
/// Example filter: 0 0 0 0 1 1 1 1 1
/// Intersection  : 0 0 0 0 1 0 1 0 0
/// PopCnt        : 2
Value GPUMappingMaskAttr::createLogicalLinearMappingId(
    OpBuilder &b, Value physicalLinearMappingId) const {
  Location loc = physicalLinearMappingId.getLoc();
  Value mask =
      arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
  Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
  Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
  filter = arith::SubIOp::create(b, loc, filter, one);
  Value filteredId = arith::AndIOp::create(b, loc, mask, filter);
  return math::CtPopOp::create(b, loc, filteredId);
}

///                 8       4       0
/// Example mask  : 0 0 0 1 1 0 1 0 0
///
/// Active physical (resp. logical) is 2 (0), 4 (1) and 5 (2).
/// Logical id for e.g. 5 (2) constructs filter (1 << 5).
///
/// Example mask  : 0 0 0 1 1 0 1 0 0
/// Example filter: 0 0 0 1 0 0 0 0 0
/// Intersection  : 0 0 0 1 0 0 0 0 0
/// Cmp           : 1
Value GPUMappingMaskAttr::createIsActiveIdPredicate(
    OpBuilder &b, Value physicalLinearMappingId) const {
  Location loc = physicalLinearMappingId.getLoc();
  Value mask =
      arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(getMask()));
  Value one = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1));
  Value filter = arith::ShLIOp::create(b, loc, one, physicalLinearMappingId);
  Value filtered = arith::AndIOp::create(b, loc, mask, filter);
  Value zero = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(0));
  return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::ne, filtered,
                               zero);
}

int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getAddressSpace());
}

bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
}

int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
}

//===----------------------------------------------------------------------===//
// MMAMatrixType
//===----------------------------------------------------------------------===//

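// Illustrative form of this type (sketch): !gpu.mma_matrix<16x16xf16, "AOp">.
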
MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
                                 StringRef operand) {
  return Base::get(elementType.getContext(), shape, elementType, operand);
}

MMAMatrixType
MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                          ArrayRef<int64_t> shape, Type elementType,
                          StringRef operand) {
  return Base::getChecked(emitError, elementType.getContext(), shape,
                          elementType, operand);
}

unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }

ArrayRef<int64_t> MMAMatrixType::getShape() const {
  return getImpl()->getShape();
}

Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }

StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }

bool MMAMatrixType::isValidElementType(Type elementType) {
  return elementType.isF16() || elementType.isF32() ||
         elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
         elementType.isInteger(32);
}

LogicalResult
MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
                                ArrayRef<int64_t> shape, Type elementType,
                                StringRef operand) {
  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return emitError() << "operand expected to be one of AOp, BOp or COp";

  if (shape.size() != 2)
    return emitError() << "MMAMatrixType must have exactly two dimensions";

  if (!MMAMatrixType::isValidElementType(elementType))
    return emitError()
           << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";

  return success();
}

//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//

bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == getWorkgroupAddressSpace();
  return false;
}

bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isWorkgroupMemoryAddressSpace(memorySpace);
}

bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

namespace {
/// This class defines the interface for handling inlining with gpu
/// operations.
struct GPUInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All gpu dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace

void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addTypes<SparseDnTensorHandleType>();
  addTypes<SparseSpMatHandleType>();
  addTypes<SparseSpGEMMOpHandleType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
                           TerminatorOp>();
  declarePromisedInterfaces<
      ValueBoundsOpInterface, ClusterDimOp, ClusterDimBlocksOp, ClusterIdOp,
      ClusterBlockIdOp, BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp, LaneIdOp,
      SubgroupIdOp, GlobalIdOp, NumSubgroupsOp, SubgroupSizeOp, LaunchOp>();
}

static StringRef getSparseHandleKeyword(SparseHandleKind kind) {
  switch (kind) {
  case SparseHandleKind::DnTensor:
    return "sparse.dntensor_handle";
  case SparseHandleKind::SpMat:
    return "sparse.spmat_handle";
  case SparseHandleKind::SpGEMMOp:
    return "sparse.spgemmop_handle";
  }
  llvm_unreachable("unknown sparse handle kind");
  return "";
}

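// Illustrative forms handled below (sketch, not exhaustive):
//   !gpu.async.token
//   !gpu.mma_matrix<16x16xf16, "AOp">
//   !gpu.sparse.spmat_handle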
Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(context);

  if (keyword == "mma_matrix") {
    SMLoc beginLoc = parser.getNameLoc();

    // Parse '<'.
    if (parser.parseLess())
      return nullptr;

    // Parse the size and elementType.
    SmallVector<int64_t> shape;
    Type elementType;
    if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
        parser.parseType(elementType))
      return nullptr;

    // Parse ','.
    if (parser.parseComma())
      return nullptr;

    // Parse operand.
    std::string operand;
    if (failed(parser.parseOptionalString(&operand)))
      return nullptr;

    // Parse '>'.
    if (parser.parseGreater())
      return nullptr;

    return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
                                         parser.getEncodedSourceLoc(beginLoc)),
                                     shape, elementType, operand);
  }

  if (keyword == getSparseHandleKeyword(SparseHandleKind::DnTensor))
    return SparseDnTensorHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpMat))
    return SparseSpMatHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpGEMMOp))
    return SparseSpGEMMOpHandleType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
  return Type();
}
// TODO: Print a refined type here; note that it should correspond to what the
// parser accepts.
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
      .Case<SparseDnTensorHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::DnTensor);
      })
      .Case<SparseSpMatHandleType>(
          [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
      .Case<SparseSpGEMMOpHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
      })
      .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
        os << "mma_matrix<";
        auto shape = fragTy.getShape();
        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
          os << *dim << 'x';
        os << shape.back() << 'x' << fragTy.getElementType();
        os << ", \"" << fragTy.getOperand() << "\"" << '>';
      })
      .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
}

static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
                                               NamedAttribute attr) {
  auto array = dyn_cast<DenseI32ArrayAttr>(attr.getValue());
  if (!array)
    return op->emitOpError(Twine(attr.getName()) +
                           " must be a dense i32 array");
  if (array.size() != 3)
    return op->emitOpError(Twine(attr.getName()) +
                           " must contain exactly 3 elements");
  return success();
}

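// The walk below assumes a nesting along these lines (illustrative sketch,
// not taken from a test):
//
//   module attributes {gpu.container_module} {
//     gpu.module @kernels {
//       gpu.func @kernel(%arg0 : f32) kernel { gpu.return }
//     }
//     // ... gpu.launch_func @kernels::@kernel ... args(%x : f32)
//   }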
LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownGridSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (!llvm::isa<UnitAttr>(attr.getValue()) ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deeply than functions in
    // the module we are currently checking.
    if (!launchOp->getParentOp() ||
        launchOp->getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp->getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName(launchOp->getName())))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel container.
    StringAttr kernelContainerName = launchOp.getKernelModuleName();
    Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
    if (!kernelContainer)
      return launchOp.emitOpError()
             << "kernel container '" << kernelContainerName.getValue()
             << "' is undefined";

    // If the container is a GPU binary op return success.
    if (isa<BinaryOp>(kernelContainer))
      return success();

    auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelContainerName.getValue()
             << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
    if (!kernelFunc)
      return launchOp.emitOpError("kernel function '")
             << launchOp.getKernel() << "' is undefined";
    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
    if (!kernelConvertedFunction) {
      InFlightDiagnostic diag = launchOp.emitOpError()
                                << "referenced kernel '" << launchOp.getKernel()
                                << "' is not a function";
      diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
      return diag;
    }

    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // TODO: If the kernel isn't a GPU function (which happens during separate
    // compilation), do not check type correspondence as it would require the
    // verifier to be aware of the type conversion.
    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
    if (!kernelGPUFunction)
      return success();

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    auto functionType = kernelGPUFunction.getFunctionType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}

/// Parses an optional list of async operands with an optional leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
///
/// This method is used by the tablegen assembly format for async ops as well.
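/// Example (illustrative): `async [%token0, %token1]`.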
static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
}

/// Prints optional async dependencies with their leading keyword.
///   (`async`)? (`[` ssa-id-list `]`)?
// Used by the tablegen assembly format for several async ops.
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async";
  if (asyncDependencies.empty())
    return;
  if (asyncTokenType)
    printer << ' ';
  printer << llvm::interleaved_array(asyncDependencies);
}

// GPU memory attribution functions shared by LaunchOp and GPUFuncOp.
/// Parses a GPU function memory attribution.
///
///    memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                           (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
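/// Example (illustrative):
///   workgroup(%sum : memref<32xf32, #gpu.address_space<workgroup>>)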
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args) {
  // If we could not parse the keyword, just assume an empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  return parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                                  /*allowType=*/true);
}

/// Prints a GPU function memory attribution.
static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values) {
  if (values.empty())
    return;

  auto printBlockArg = [](BlockArgument v) {
    return llvm::formatv("{} : {}", v, v.getType());
  };
  p << ' ' << keyword << '('
    << llvm::interleaved(llvm::map_range(values, printBlockArg)) << ')';
}

/// Verifies a GPU function memory attribution.
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        gpu::AddressSpace memorySpace) {
  for (Value v : attributions) {
    auto type = llvm::dyn_cast<MemRefType>(v.getType());
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    // We can only verify the address space if it hasn't already been lowered
    // from the AddressSpaceAttr to a target-specific numeric value.
    auto addressSpace =
        llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
    if (!addressSpace)
      continue;
    if (addressSpace.getValue() != memorySpace)
      return op->emitOpError()
             << "expected memory space " << stringifyAddressSpace(memorySpace)
             << " in attribution";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//

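// Illustrative forms, mirroring the op documentation (sketch):
//
//   %1 = gpu.all_reduce add %0 {} : (f32) -> (f32)
//   %2 = gpu.all_reduce %0 {
//   ^bb(%lhs : f32, %rhs : f32):
//     %sum = arith.addf %lhs, %rhs : f32
//     "gpu.yield"(%sum) : (f32) -> ()
//   } : (f32) -> (f32)
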
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
                                           Type resType) {
  using Kind = gpu::AllReduceOperation;
  if (llvm::is_contained(
          {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
          opName)) {
    if (!isa<FloatType>(resType))
      return failure();
  }

  if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
                          Kind::AND, Kind::OR, Kind::XOR},
                         opName)) {
    if (!isa<IntegerType>(resType))
      return failure();
  }

  return success();
}

LogicalResult gpu::AllReduceOp::verifyRegions() {
  if (getBody().empty() != getOp().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    if (getBody().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    gpu::AllReduceOperation opName = *getOp();
    if (failed(verifyReduceOpAndType(opName, getType()))) {
      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                         << "` reduction operation is not compatible with type "
                         << getType();
    }
  }

  return success();
}

static bool canMakeGroupOpUniform(Operation *op) {
  auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
  if (!launchOp)
    return false;

  Region &body = launchOp.getBody();
  assert(!body.empty() && "Invalid region");

  // Only convert ops in the gpu::launch entry block for now.
  return op->getBlock() == &body.front();
}

OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

// TODO: Support optional custom attributes (without dialect prefix).
static ParseResult parseAllReduceOperation(AsmParser &parser,
                                           AllReduceOperationAttr &attr) {
  StringRef enumStr;
  if (!parser.parseOptionalKeyword(&enumStr)) {
    std::optional<AllReduceOperation> op =
        gpu::symbolizeAllReduceOperation(enumStr);
    if (!op)
      return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
    attr = AllReduceOperationAttr::get(parser.getContext(), *op);
  }
  return success();
}

static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
                                    AllReduceOperationAttr attr) {
  if (attr)
    attr.print(printer);
}

//===----------------------------------------------------------------------===//
// SubgroupReduceOp
//===----------------------------------------------------------------------===//

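// Illustrative forms (sketch; the `cluster(...)` clause is optional):
//
//   %1 = gpu.subgroup_reduce add %a : (f32) -> f32
//   %2 = gpu.subgroup_reduce add %b cluster(size = 4, stride = 2)
//       : (vector<4xf16>) -> vector<4xf16>
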
LogicalResult gpu::SubgroupReduceOp::verify() {
  Type elemType = getType();
  if (auto vecTy = dyn_cast<VectorType>(elemType)) {
    if (vecTy.isScalable())
      return emitOpError() << "is not compatible with scalable vector types";

    elemType = vecTy.getElementType();
  }

  gpu::AllReduceOperation opName = getOp();
  if (failed(verifyReduceOpAndType(opName, elemType))) {
    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                       << "` reduction operation is not compatible with type "
                       << getType();
  }

  auto clusterSize = getClusterSize();
  if (clusterSize) {
    uint32_t size = *clusterSize;
    if (!llvm::isPowerOf2_32(size)) {
      return emitOpError() << "cluster size " << size
                           << " is not a power of two";
    }
  }

  uint32_t stride = getClusterStride();
  if (stride != 1 && !clusterSize) {
    return emitOpError() << "cluster stride can only be specified if cluster "
                            "size is specified";
  }
  if (!llvm::isPowerOf2_32(stride)) {
    return emitOpError() << "cluster stride " << stride
                         << " is not a power of two";
  }

  return success();
}

OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (getClusterSize() == 1)
    return getValue();

  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AsyncOpInterface
//===----------------------------------------------------------------------===//

void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(0, {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);

  // The async dependency list is the only variadic operand.
  if (!sizeAttr)
    return;

  SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
  ++sizes.front();
  op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
}

//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value getBlockSizeX, Value getBlockSizeY,
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,
                     Type asyncTokenType, ValueRange asyncDependencies,
                     TypeRange workgroupAttributions,
                     TypeRange privateAttributions, Value clusterSizeX,
                     Value clusterSizeY, Value clusterSizeZ,
                     FlatSymbolRefAttr module, FlatSymbolRefAttr function) {
  OpBuilder::InsertionGuard g(builder);

  // Add a WorkGroup attribution attribute. This attribute is required to
  // identify private attributions in the list of block arguments.
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));

  // Add Op operands.
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                      getBlockSizeY, getBlockSizeZ});
  if (clusterSizeX)
    result.addOperands(clusterSizeX);
  if (clusterSizeY)
    result.addOperands(clusterSizeY);
  if (clusterSizeZ)
    result.addOperands(clusterSizeZ);
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  // Add optional module and function attributes.
  if (module)
    result.addAttribute(getModuleAttrName(result.name), module);
  if (function)
    result.addAttribute(getFunctionAttrName(result.name), function);

  // Create a kernel body region with kNumConfigRegionAttributes + N memory
  // attributions, where the first kNumConfigRegionAttributes arguments have
  // `index` type and the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = builder.createBlock(kernelRegion);
  // TODO: Allow passing in proper locations here.
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(builder.getIndexType(), result.location);
  // Add WorkGroup & Private attributions to the region arguments.
  for (Type argTy : workgroupAttributions)
    body->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    body->addArgument(argTy, result.location);
  // Fill the OperandSegmentSize attribute.
  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  segmentSizes[7] = clusterSizeX ? 1 : 0;
  segmentSizes[8] = clusterSizeY ? 1 : 0;
  segmentSizes[9] = clusterSizeZ ? 1 : 0;
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}

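// The leading region arguments of the launch body are, in order: block ids
// (arguments 0-2), thread ids (3-5), grid sizes (6-8) and block sizes (9-11);
// when a cluster is present, cluster ids (12-14) and cluster sizes (15-17)
// follow, as the accessors below reflect.
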
KernelDim3 LaunchOp::getBlockIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

std::optional<KernelDim3> LaunchOp::getClusterIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[12], args[13], args[14]};
}

std::optional<KernelDim3> LaunchOp::getClusterSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[15], args[16], args[17]};
}

KernelDim3 LaunchOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  if (!hasClusterSize())
    return std::nullopt;
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchOp::verify() {
  if (!(hasClusterSize()) &&
      (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
    return emitOpError() << "cluster size must be all present";
  return success();
}

LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!getBody().empty()) {
    if (getBody().getNumArguments() <
        kNumConfigRegionAttributes + getNumWorkgroupAttributions())
      return emitOpError("unexpected number of region arguments");
  }

  // Verify attribution address spaces.
  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}

// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}

void LaunchOp::print(OpAsmPrinter &p) {
  if (getAsyncToken()) {
    p << " async";
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }
  // Print the launch configuration.
  if (hasClusterSize()) {
    p << ' ' << getClustersKeyword();
    printSizeAssignment(p, getClusterSize().value(),
                        getClusterSizeOperandValues().value(),
                        getClusterIds().value());
  }
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

  // Print the optional module attribute.
  StringRef moduleAttrName = getModuleAttrName();
  if (auto module = getModule()) {
    p << ' ' << moduleAttrName << '(';
    p.printSymbolName(*module);
    p << ')';
  }
  // Print the optional function attribute.
  StringRef functionAttrName = getFunctionAttrName();
  if (auto function = getFunction()) {
    p << ' ' << functionAttrName << '(';
    p.printSymbolName(*function);
    p << ')';
  }

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());

  p << ' ';

  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              LaunchOp::getOperandSegmentSizeAttr(),
                              getNumWorkgroupAttributionsAttrName(),
                              moduleAttrName, functionAttrName});
}

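// Illustrative printed form (sketch; %0 ... %5 are index-typed size operands):
//
//   gpu.launch blocks(%bx, %by, %bz) in (%gx = %0, %gy = %1, %gz = %2)
//              threads(%tx, %ty, %tz) in (%sx = %3, %sy = %4, %sz = %5) {
//     ...
//     gpu.terminator
//   }
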
// Parse the size assignment blocks for blocks and threads. These have the form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
  if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
                              /*allowResultNumber=*/false) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
        parser.parseEqual() || parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}

/// Parses a Launch operation.
/// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
///       (`clusters` `(` ssa-id-list `)` `in` ssa-reassignment)?
///       `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
///       `threads` `(` ssa-id-list `)` `in` ssa-reassignment
///       (`dynamic_shared_memory_size` ssa-use)?
///       (`module(` symbol-ref-id `)`)?
///       (`function(` symbol-ref-id `)`)?
///       memory-attribution
///       region attr-dict?
/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
      sizes(LaunchOp::kNumConfigOperands);

  // Region arguments to be created.
  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);

  // Parse optional async dependencies.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
  Type asyncTokenType;
  if (failed(
          parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
      parser.resolveOperands(asyncDependencies, asyncTokenType,
                             result.operands))
    return failure();
  if (parser.getNumResults() > 0)
    result.types.push_back(asyncTokenType);

  bool hasCluster = false;
  if (succeeded(
          parser.parseOptionalKeyword(LaunchOp::getClustersKeyword().data()))) {
    hasCluster = true;
    sizes.resize(9);
    regionArgs.resize(18);
  }
  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);

  // The last three operand segments assign the cluster size; in the region
  // argument list, these are the last six arguments.
  if (hasCluster) {
    if (parseSizeAssignment(parser, sizesRef.drop_front(6),
                            regionArgsRef.slice(15, 3),
                            regionArgsRef.slice(12, 3)))
      return failure();
  }
  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers. In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
  bool hasDynamicSharedMemorySize = false;
  if (!parser.parseOptionalKeyword(
          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
    hasDynamicSharedMemorySize = true;
    if (parser.parseOperand(dynamicSharedMemorySize) ||
        parser.resolveOperand(dynamicSharedMemorySize,
                              parser.getBuilder().getI32Type(),
                              result.operands))
      return failure();
  }

  // Parse the optional module attribute.
  StringRef moduleAttrName = getModuleAttrName(result.name);
  if (succeeded(parser.parseOptionalKeyword(moduleAttrName))) {
    FlatSymbolRefAttr moduleSymbol;
    if (parser.parseLParen() ||
        parser.parseAttribute(moduleSymbol, Type(), moduleAttrName,
                              result.attributes) ||
        parser.parseRParen())
      return failure();
  }
  // Parse the optional function attribute.
  StringRef functionAttrName = getFunctionAttrName(result.name);
  if (succeeded(parser.parseOptionalKeyword(functionAttrName))) {
    FlatSymbolRefAttr funcSymbol;
    if (parser.parseLParen() ||
        parser.parseAttribute(funcSymbol, Type(), functionAttrName,
                              result.attributes) ||
        parser.parseRParen())
      return failure();
  }

  // Create the region arguments. The region has kNumConfigRegionAttributes
  // arguments that correspond to block/thread identifiers and grid/block
  // sizes, all having `index` type, plus a variadic number of workgroup
  // attributions and a variadic number of private attributions. The number of
  // workgroup attributions is stored in the attribute named by
  // LaunchOp::getNumWorkgroupAttributionsAttrName().
  Type index = parser.getBuilder().getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes + 6, index);

  SmallVector<OpAsmParser::Argument> regionArguments;
  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
    OpAsmParser::Argument arg;
    arg.ssaName = std::get<0>(ssaValueAndType);
    arg.type = std::get<1>(ssaValueAndType);
    regionArguments.push_back(arg);
  }

  Builder &builder = parser.getBuilder();
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
                               regionArguments)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = regionArguments.size() -
                               LaunchOp::kNumConfigRegionAttributes -
                               (hasCluster ? 6 : 0);
  result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
                               regionArguments)))
    return failure();

  // Introduce the body region and parse it. The region has
  // kNumConfigRegionAttributes arguments that correspond to
  // block/thread identifiers and grid/block sizes, all having `index` type.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArguments) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();

  if (!hasCluster) {
    segmentSizes[7] = 0;
    segmentSizes[8] = 0;
    segmentSizes[9] = 0;
  }
  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
  return success();
}

/// Simplify the gpu.launch when the range of a thread or block ID is
/// trivially known to be one.
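///
/// For example, if the grid size along x is a constant 1, every use of the x
/// block id inside the body can be replaced with a constant 0.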
struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter &rewriter) const override {
    // If the range implies a single value for `id`, replace `id`'s uses by
    // zero.
    Value zero;
    bool simplified = false;
    auto constPropIdUses = [&](Value id, Value size) {
      // Check if size is trivially one.
      if (!matchPattern(size, m_One()))
        return;
      if (id.getUses().empty())
        return;
      if (!simplified) {
        // Create a zero value the first time.
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(&op.getBody().front());
        zero =
            arith::ConstantIndexOp::create(rewriter, op.getLoc(), /*value=*/0);
      }
      rewriter.replaceAllUsesWith(id, zero);
      simplified = true;
    };
    constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
    constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
    constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
    constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
    constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
    constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());

    return success(simplified);
  }
};

void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
                                           MLIRContext *context) {
  rewrites.add<FoldLaunchArguments>(context);
}

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//

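// Illustrative textual form (sketch):
//
//   gpu.launch_func @kernels::@kernel
//       blocks in (%gx, %gy, %gz) threads in (%bx, %by, %bz)
//       args(%arg0 : f32, %arg1 : memref<?xf32>)
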
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  assert(kernelSymbol.getNestedReferences().size() == 1 &&
         "expected a symbol reference with a single nested reference");
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);

  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernelSymbol;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  llvm::fill(prop.operandSegmentSizes, 1);
  prop.operandSegmentSizes[0] = asyncDependencies.size();
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
  auto kernelSymbol =
      SymbolRefAttr::get(kernelModule.getNameAttr(),
                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});
  build(builder, result, kernelSymbol, gridSize, getBlockSize,
        dynamicSharedMemorySize, kernelOperands, asyncTokenType,
        asyncDependencies, clusterSize);
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernel, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Value asyncObject,
                         std::optional<KernelDim3> clusterSize) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  if (asyncObject)
    result.addOperands(asyncObject);
  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernel;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  llvm::fill(prop.operandSegmentSizes, 1);
  prop.operandSegmentSizes[0] = 0;
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
}

StringAttr LaunchFuncOp::getKernelModuleName() {
  return getKernel().getRootReference();
}

StringAttr LaunchFuncOp::getKernelName() {
  return getKernel().getLeafReference();
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getKernelOperands().size();
}

Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getKernelOperands()[i];
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
  assert(hasClusterSize() &&
         "cluster size is not set, check hasClusterSize() first");
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchFuncOp::verify() {
  auto module = (*this)->getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module->getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  if (hasClusterSize()) {
    if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
        getClusterSizeZ().getType() != getClusterSizeX().getType())
      return emitOpError()
             << "expects types of the cluster dimensions must be the same";
  }

  return success();
}

static ParseResult
parseLaunchDimType(OpAsmParser &parser, Type &dimTy,
                   std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
                   Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
  if (succeeded(parser.parseOptionalColon())) {
    if (parser.parseType(dimTy))
      return failure();
  } else {
    dimTy = IndexType::get(parser.getContext());
  }
  if (clusterValue.has_value()) {
    clusterXTy = clusterYTy = clusterZTy = dimTy;
  }
  return success();
}

static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
                               Value clusterValue, Type clusterXTy,
                               Type clusterYTy, Type clusterZTy) {
  if (!dimTy.isIndex())
    printer << ": " << dimTy;
}

static ParseResult parseLaunchFuncOperands(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
    SmallVectorImpl<Type> &argTypes) {
  if (parser.parseOptionalKeyword("args"))
    return success();

  auto parseElement = [&]() -> ParseResult {
    return failure(parser.parseOperand(argNames.emplace_back()) ||
                   parser.parseColonType(argTypes.emplace_back()));
  };

  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
                                        parseElement, " in argument list");
}

static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
                                    OperandRange operands, TypeRange types) {
  if (operands.empty())
    return;
  printer << "args(";
  llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
                        [&](const auto &pair) {
                          auto [operand, type] = pair;
                          printer << operand << " : " << type;
                        });
  printer << ")";
}

//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//

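// Illustrative form (sketch):
//   %shfl, %valid = gpu.shuffle xor %val, %offset, %width : f32
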
void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
                      int32_t offset, int32_t width, ShuffleMode mode) {
  build(builder, result, value,
        arith::ConstantOp::create(builder, result.location,
                                  builder.getI32IntegerAttr(offset)),
        arith::ConstantOp::create(builder, result.location,
                                  builder.getI32IntegerAttr(width)),
        mode);
}

//===----------------------------------------------------------------------===//
// RotateOp
//===----------------------------------------------------------------------===//

LogicalResult RotateOp::verify() {
  uint32_t offset = getOffset();
  uint32_t width = getWidth();

  if (offset >= width) {
    return emitOpError() << "offset must be in the range [0, " << width << ")";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// BarrierOp
//===----------------------------------------------------------------------===//

namespace {

/// Erase a gpu.barrier that is immediately followed by another gpu.barrier;
/// the threads are already synchronized!
LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
                                          PatternRewriter &rewriter) {
  if (isa_and_nonnull<BarrierOp>(op->getNextNode())) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}

} // end anonymous namespace

void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add(eraseRedundantGpuBarrierOps);
}

//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      getFunctionType().getNumInputs() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                      StringRef name, FunctionType type,
                      TypeRange workgroupAttributions,
                      TypeRange privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  OpBuilder::InsertionGuard g(builder);

  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));
  result.addAttributes(attrs);
  Region *body = result.addRegion();
  Block *entryBlock = builder.createBlock(body);

  // TODO: Allow passing in proper locations here.
  for (Type argTy : type.getInputs())
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : workgroupAttributions)
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    entryBlock->addArgument(argTy, result.location);
}
1527 
1528 /// Parses a GPU function memory attribution.
1529 ///
1530 /// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
1531 /// (`private` `(` ssa-id-and-type-list `)`)?
1532 ///
1533 /// Note that this function parses only one of the two similar parts, with the
1534 /// keyword provided as argument.
1535 static ParseResult
1536 parseAttributions(OpAsmParser &parser, StringRef keyword,
1538  Attribute &attributionAttrs) {
1539  // If we could not parse the keyword, just assume empty list and succeed.
1540  if (failed(parser.parseOptionalKeyword(keyword)))
1541  return success();
1542 
1543  size_t existingArgs = args.size();
1544  ParseResult result =
1546  /*allowType=*/true, /*allowAttrs=*/true);
1547  if (failed(result))
1548  return result;
1549 
1550  bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
1551  [](const OpAsmParser::Argument &arg) -> bool {
1552  return arg.attrs && !arg.attrs.empty();
1553  });
1554  if (!hadAttrs) {
1555  attributionAttrs = nullptr;
1556  return result;
1557  }
1558 
1559  Builder &builder = parser.getBuilder();
1560  SmallVector<Attribute> attributionAttrsVec;
1561  for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
1562  if (!argument.attrs)
1563  attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
1564  else
1565  attributionAttrsVec.push_back(argument.attrs);
1566  }
1567  attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
1568  return result;
1569 }
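// For illustration (a sketch consistent with the grammar above), an
// attribution list in the assembly may look like:
//
//   workgroup(%sbuf : memref<32xf32, #gpu.address_space<workgroup>>)
//   private(%pbuf : memref<4xf32, #gpu.address_space<private>>)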
1570 
1571 /// Parses a GPU function.
1572 ///
1573 /// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
1574 /// (`->` function-result-list)? memory-attribution `kernel`?
1575 /// function-attributes? region
1576 ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
1577  SmallVector<OpAsmParser::Argument> entryArgs;
1578  SmallVector<DictionaryAttr> resultAttrs;
1579  SmallVector<Type> resultTypes;
1580  bool isVariadic;
1581 
1582  // Parse the function name.
1583  StringAttr nameAttr;
1584  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
1585  result.attributes))
1586  return failure();
1587 
1588  auto signatureLocation = parser.getCurrentLocation();
1589  if (failed(function_interface_impl::parseFunctionSignatureWithArguments(
1590  parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
1591  resultAttrs)))
1592  return failure();
1593 
1594  if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
1595  return parser.emitError(signatureLocation)
1596  << "gpu.func requires named arguments";
1597 
1598  // Construct the function type. More types will be added to the region, but
1599  // not to the function type.
1600  Builder &builder = parser.getBuilder();
1601 
1602  SmallVector<Type> argTypes;
1603  for (auto &arg : entryArgs)
1604  argTypes.push_back(arg.type);
1605  auto type = builder.getFunctionType(argTypes, resultTypes);
1606  result.addAttribute(getFunctionTypeAttrName(result.name),
1607  TypeAttr::get(type));
1608 
1609  call_interface_impl::addArgAndResultAttrs(
1610  builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
1611  getResAttrsAttrName(result.name));
1612 
1613  Attribute workgroupAttributionAttrs;
1614  // Parse workgroup memory attributions.
1615  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
1616  entryArgs, workgroupAttributionAttrs)))
1617  return failure();
1618 
1619  // Store the number of operands we just parsed as the number of workgroup
1620  // memory attributions.
1621  unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
1622  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
1623  builder.getI64IntegerAttr(numWorkgroupAttrs));
1624  if (workgroupAttributionAttrs)
1625  result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
1626  workgroupAttributionAttrs);
1627 
1628  Attribute privateAttributionAttrs;
1629  // Parse private memory attributions.
1630  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
1631  entryArgs, privateAttributionAttrs)))
1632  return failure();
1633  if (privateAttributionAttrs)
1634  result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
1635  privateAttributionAttrs);
1636 
1637  // Parse the kernel attribute if present.
1638  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
1639  result.addAttribute(GPUDialect::getKernelFuncAttrName(),
1640  builder.getUnitAttr());
1641 
1642  // Parse attributes.
1643  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
1644  return failure();
1645 
1646  // Parse the region. If no argument names were provided, take all names
1647  // (including those of attributions) from the entry block.
1648  auto *body = result.addRegion();
1649  return parser.parseRegion(*body, entryArgs);
1650 }
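// A sketch of the full accepted syntax (illustrative; names and types are
// hypothetical):
//
//   gpu.func @kernel(%arg0 : f32, %arg1 : memref<?xf32>)
//       workgroup(%sbuf : memref<32xf32, #gpu.address_space<workgroup>>)
//       private(%tmp : memref<1xf32, #gpu.address_space<private>>)
//       kernel {
//     ...
//     gpu.return
//   }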
1651 
1652 static void printAttributions(OpAsmPrinter &p, StringRef keyword,
1653  ArrayRef<BlockArgument> values,
1654  ArrayAttr attributes) {
1655  if (values.empty())
1656  return;
1657 
1658  p << ' ' << keyword << '(';
1659  llvm::interleaveComma(
1660  llvm::enumerate(values), p, [&p, attributes](auto pair) {
1661  BlockArgument v = pair.value();
1662  p << v << " : " << v.getType();
1663 
1664  size_t attributionIndex = pair.index();
1665  DictionaryAttr attrs;
1666  if (attributes && attributionIndex < attributes.size())
1667  attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
1668  if (attrs)
1669  p.printOptionalAttrDict(attrs.getValue());
1670  });
1671  p << ')';
1672 }
1673 
1674 void GPUFuncOp::print(OpAsmPrinter &p) {
1675  p << ' ';
1676  p.printSymbolName(getName());
1677 
1678  FunctionType type = getFunctionType();
1679  function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
1680  /*isVariadic=*/false,
1681  type.getResults());
1682 
1683  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
1684  getWorkgroupAttribAttrs().value_or(nullptr));
1685  printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
1686  getPrivateAttribAttrs().value_or(nullptr));
1687  if (isKernel())
1688  p << ' ' << getKernelKeyword();
1689 
1690  function_interface_impl::printFunctionAttributes(
1691  p, *this,
1692  {getNumWorkgroupAttributionsAttrName(),
1693  GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1694  getArgAttrsAttrName(), getResAttrsAttrName(),
1695  getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
1696  p << ' ';
1697  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
1698 }
1699 
1700 static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
1701  StringAttr attrName) {
1702  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1703  if (!allAttrs || index >= allAttrs.size())
1704  return DictionaryAttr();
1705  return llvm::cast<DictionaryAttr>(allAttrs[index]);
1706 }
1707 
1708 DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
1709  return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
1710 }
1711 
1712 DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
1713  return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
1714 }
1715 
1716 static void setAttributionAttrs(GPUFuncOp op, unsigned index,
1717  DictionaryAttr value, StringAttr attrName) {
1718  MLIRContext *ctx = op.getContext();
1719  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1720  SmallVector<Attribute> elements;
1721  if (allAttrs)
1722  elements.append(allAttrs.begin(), allAttrs.end());
1723  while (elements.size() <= index)
1724  elements.push_back(DictionaryAttr::get(ctx));
1725  if (!value)
1726  elements[index] = DictionaryAttr::get(ctx);
1727  else
1728  elements[index] = value;
1729  ArrayAttr newValue = ArrayAttr::get(ctx, elements);
1730  op->setAttr(attrName, newValue);
1731 }
1732 
1733 void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
1734  DictionaryAttr value) {
1735  setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
1736 }
1737 
1738 void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
1739  DictionaryAttr value) {
1740  setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
1741 }
1742 
1743 static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
1744  StringAttr name, StringAttr attrsName) {
1745  DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
1746  if (!dict)
1747  return Attribute();
1748  return dict.get(name);
1749 }
1750 
1751 Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
1752  StringAttr name) {
1753  assert(index < getNumWorkgroupAttributions() &&
1754  "index must map to a workgroup attribution");
1755  return getAttributionAttr(*this, index, name,
1756  getWorkgroupAttribAttrsAttrName());
1757 }
1758 
1759 Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
1760  StringAttr name) {
1761  assert(index < getNumPrivateAttributions() &&
1762  "index must map to a private attribution");
1763  return getAttributionAttr(*this, index, name,
1764  getPrivateAttribAttrsAttrName());
1765 }
1766 
1767 static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
1768  Attribute value, StringAttr attrsName) {
1769  MLIRContext *ctx = op.getContext();
1770  SmallVector<NamedAttribute> elems;
1771  DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
1772  if (oldDict)
1773  elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1774 
1775  bool found = false;
1776  bool mustSort = true;
1777  for (unsigned i = 0, e = elems.size(); i < e; ++i) {
1778  if (elems[i].getName() == name) {
1779  found = true;
1780  if (!value) {
1781  std::swap(elems[i], elems[elems.size() - 1]);
1782  elems.pop_back();
1783  } else {
1784  mustSort = false;
1785  elems[i] = NamedAttribute(elems[i].getName(), value);
1786  }
1787  break;
1788  }
1789  }
1790  if (!found) {
1791  if (!value)
1792  return;
1793  elems.emplace_back(name, value);
1794  }
1795  if (mustSort) {
1796  DictionaryAttr::sortInPlace(elems);
1797  }
1798  auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1799  setAttributionAttrs(op, index, newDict, attrsName);
1800 }
1801 
1802 void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
1803  Attribute value) {
1804  assert(index < getNumWorkgroupAttributions() &&
1805  "index must map to a workgroup attribution");
1806  setAttributionAttr(*this, index, name, value,
1807  getWorkgroupAttribAttrsAttrName());
1808 }
1809 
1810 void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
1811  Attribute value) {
1812  assert(index < getNumPrivateAttributions() &&
1813  "index must map to a private attribution");
1814  setAttributionAttr(*this, index, name, value,
1815  getPrivateAttribAttrsAttrName());
1816 }
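// Hedged usage sketch (the attribute name "my.align" is hypothetical): the
// setters above grow the per-attribution dictionary array on demand and keep
// each dictionary sorted.
//
//   funcOp.setWorkgroupAttributionAttr(/*index=*/0,
//                                      b.getStringAttr("my.align"),
//                                      b.getI64IntegerAttr(16));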
1817 
1818 LogicalResult GPUFuncOp::verifyType() {
1819  if (isKernel() && getFunctionType().getNumResults() != 0)
1820  return emitOpError() << "expected void return type for kernel function";
1821 
1822  return success();
1823 }
1824 
1825 /// Verifies the body of the function.
1826 LogicalResult GPUFuncOp::verifyBody() {
1827  if (empty())
1828  return emitOpError() << "expected body with at least one block";
1829  unsigned numFuncArguments = getNumArguments();
1830  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1831  unsigned numBlockArguments = front().getNumArguments();
1832  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1833  return emitOpError() << "expected at least "
1834  << numFuncArguments + numWorkgroupAttributions
1835  << " arguments to body region";
1836 
1837  ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1838  for (unsigned i = 0; i < numFuncArguments; ++i) {
1839  Type blockArgType = front().getArgument(i).getType();
1840  if (funcArgTypes[i] != blockArgType)
1841  return emitOpError() << "expected body region argument #" << i
1842  << " to be of type " << funcArgTypes[i] << ", got "
1843  << blockArgType;
1844  }
1845 
1846  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
1847  GPUDialect::getWorkgroupAddressSpace())) ||
1848  failed(verifyAttributions(getOperation(), getPrivateAttributions(),
1849  GPUDialect::getPrivateAddressSpace())))
1850  return failure();
1851 
1852  return success();
1853 }
1854 
1855 //===----------------------------------------------------------------------===//
1856 // ReturnOp
1857 //===----------------------------------------------------------------------===//
1858 
1859 LogicalResult gpu::ReturnOp::verify() {
1860  GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();
1861 
1862  FunctionType funType = function.getFunctionType();
1863 
1864  if (funType.getNumResults() != getOperands().size())
1865  return emitOpError()
1866  .append("expected ", funType.getNumResults(), " result operands")
1867  .attachNote(function.getLoc())
1868  .append("return type declared here");
1869 
1870  for (const auto &pair : llvm::enumerate(
1871  llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
1872  auto [type, operand] = pair.value();
1873  if (type != operand.getType())
1874  return emitOpError() << "unexpected type `" << operand.getType()
1875  << "' for operand #" << pair.index();
1876  }
1877  return success();
1878 }
1879 
1880 //===----------------------------------------------------------------------===//
1881 // GPUModuleOp
1882 //===----------------------------------------------------------------------===//
1883 
1884 void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1885  StringRef name, ArrayAttr targets,
1886  Attribute offloadingHandler) {
1887  result.addRegion()->emplaceBlock();
1888  Properties &props = result.getOrAddProperties<Properties>();
1889  if (targets)
1890  props.targets = targets;
1891  props.setSymName(builder.getStringAttr(name));
1892  props.offloadingHandler = offloadingHandler;
1893 }
1894 
1895 void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1896  StringRef name, ArrayRef<Attribute> targets,
1897  Attribute offloadingHandler) {
1898  build(builder, result, name,
1899  targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
1900  offloadingHandler);
1901 }
1902 
1903 bool GPUModuleOp::hasTarget(Attribute target) {
1904  if (ArrayAttr targets = getTargetsAttr())
1905  return llvm::count(targets.getValue(), target);
1906  return false;
1907 }
1908 
1909 void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
1910  ArrayAttr &targetsAttr = getProperties().targets;
1911  SmallVector<Attribute> targetsVector(targets);
1912  targetsAttr = ArrayAttr::get(getContext(), targetsVector);
1913 }
1914 
1915 LogicalResult GPUModuleOp::verify() {
1916  auto targets = getOperation()->getAttrOfType<ArrayAttr>("targets");
1917 
1918  if (!targets)
1919  return success();
1920 
1921  for (auto target : targets) {
1922  if (auto verifyTargetAttr =
1923  llvm::dyn_cast<TargetAttrVerifyInterface>(target)) {
1924  if (verifyTargetAttr.verifyTarget(getOperation()).failed())
1925  return failure();
1926  }
1927  }
1928  return success();
1929 }
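// Illustrative IR sketch: a module carrying a target attribute, which the
// verifier above checks when the attribute implements
// TargetAttrVerifyInterface (the concrete target depends on the enabled
// backends).
//
//   gpu.module @kernels [#nvvm.target<chip = "sm_90">] {
//     ...
//   }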
1930 
1931 //===----------------------------------------------------------------------===//
1932 // GPUBinaryOp
1933 //===----------------------------------------------------------------------===//
1934 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1935  Attribute offloadingHandler, ArrayAttr objects) {
1936  auto &properties = result.getOrAddProperties<Properties>();
1937  result.attributes.push_back(builder.getNamedAttr(
1938  SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1939  properties.objects = objects;
1940  if (offloadingHandler)
1941  properties.offloadingHandler = offloadingHandler;
1942  else
1943  properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
1944 }
1945 
1946 void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1947  Attribute offloadingHandler, ArrayRef<Attribute> objects) {
1948  build(builder, result, name, offloadingHandler,
1949  objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
1950 }
1951 
1952 static ParseResult parseOffloadingHandler(OpAsmParser &parser,
1953  Attribute &offloadingHandler) {
1954  if (succeeded(parser.parseOptionalLess())) {
1955  if (parser.parseAttribute(offloadingHandler))
1956  return failure();
1957  if (parser.parseGreater())
1958  return failure();
1959  }
1960  if (!offloadingHandler)
1961  offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
1962  return success();
1963 }
1964 
1965 static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op,
1966  Attribute offloadingHandler) {
1967  if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
1968  printer << '<' << offloadingHandler << '>';
1969 }
1970 
1971 //===----------------------------------------------------------------------===//
1972 // GPUMemcpyOp
1973 //===----------------------------------------------------------------------===//
1974 
1975 LogicalResult MemcpyOp::verify() {
1976  auto srcType = getSrc().getType();
1977  auto dstType = getDst().getType();
1978 
1979  if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
1980  return emitOpError("arguments have incompatible element type");
1981 
1982  if (failed(verifyCompatibleShape(srcType, dstType)))
1983  return emitOpError("arguments have incompatible shape");
1984 
1985  return success();
1986 }
1987 
1988 namespace {
1989 
1990 /// Erases a common case of copy ops where a destination value is used only by
1991 /// the copy op, alloc and dealloc ops.
1992 struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
1993  using OpRewritePattern<MemcpyOp>::OpRewritePattern;
1994 
1995  LogicalResult matchAndRewrite(MemcpyOp op,
1996  PatternRewriter &rewriter) const override {
1997  Value dest = op.getDst();
1998  Operation *destDefOp = dest.getDefiningOp();
1999  // `dest` must be defined by an op having Allocate memory effect in order to
2000  // perform the folding.
2001  if (!destDefOp ||
2002  !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
2003  return failure();
2004  // We can erase `op` iff `dest` has no other use apart from its
2005  // use by `op` and dealloc ops.
2006  if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
2007  return user != op &&
2008  !hasSingleEffect<MemoryEffects::Free>(user, dest);
2009  }))
2010  return failure();
2011  // We can perform the folding if and only if op has a single async
2012  // dependency and produces an async token as result, or if it does not have
2013  // any async dependency and does not produce any async token result.
2014  if (op.getAsyncDependencies().size() > 1 ||
2015  ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
2016  (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
2017  return failure();
2018  rewriter.replaceOp(op, op.getAsyncDependencies());
2019  return success();
2020  }
2021 };
2022 
2023 } // end anonymous namespace
2024 
2025 void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
2026  MLIRContext *context) {
2027  results.add<EraseTrivialCopyOp>(context);
2028 }
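// IR sketch of the trivial-copy erasure (illustrative): %dst is only
// allocated, copied into, and deallocated, so the copy is dead.
//
//   %dst = gpu.alloc () : memref<16xf32>
//   gpu.memcpy %dst, %src : memref<16xf32>, memref<16xf32>
//   gpu.dealloc %dst : memref<16xf32>
//
// The gpu.memcpy is erased (and replaced by its async dependencies, if any).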
2029 
2030 //===----------------------------------------------------------------------===//
2031 // GPU_SubgroupMmaLoadMatrixOp
2032 //===----------------------------------------------------------------------===//
2033 
2034 LogicalResult SubgroupMmaLoadMatrixOp::verify() {
2035  auto srcType = getSrcMemref().getType();
2036  auto resType = getRes().getType();
2037  auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
2038  auto operand = resMatrixType.getOperand();
2039  auto srcMemrefType = llvm::cast<MemRefType>(srcType);
2040 
2041  if (!srcMemrefType.isLastDimUnitStride())
2042  return emitError(
2043  "expected source memref most minor dim must have unit stride");
2044 
2045  if (operand != "AOp" && operand != "BOp" && operand != "COp")
2046  return emitError("only AOp, BOp and COp can be loaded");
2047 
2048  return success();
2049 }
2050 
2051 //===----------------------------------------------------------------------===//
2052 // GPU_SubgroupMmaStoreMatrixOp
2053 //===----------------------------------------------------------------------===//
2054 
2055 LogicalResult SubgroupMmaStoreMatrixOp::verify() {
2056  auto srcType = getSrc().getType();
2057  auto dstType = getDstMemref().getType();
2058  auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
2059  auto dstMemrefType = llvm::cast<MemRefType>(dstType);
2060 
2061  if (!dstMemrefType.isLastDimUnitStride())
2062  return emitError(
2063  "expected destination memref most minor dim must have unit stride");
2064 
2065  if (srcMatrixType.getOperand() != "COp")
2066  return emitError(
2067  "expected the operand matrix being stored to have 'COp' operand type");
2068 
2069  return success();
2070 }
2071 
2072 //===----------------------------------------------------------------------===//
2073 // GPU_SubgroupMmaComputeOp
2074 //===----------------------------------------------------------------------===//
2075 
2076 LogicalResult SubgroupMmaComputeOp::verify() {
2077  enum OperandMap { A, B, C };
2078  SmallVector<MMAMatrixType, 3> opTypes;
2079  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
2080  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
2081  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
2082 
2083  if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
2084  opTypes[C].getOperand() != "COp")
2085  return emitError("operands must be in the order AOp, BOp, COp");
2086 
2087  ArrayRef<int64_t> aShape, bShape, cShape;
2088  aShape = opTypes[A].getShape();
2089  bShape = opTypes[B].getShape();
2090  cShape = opTypes[C].getShape();
2091 
2092  if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
2093  bShape[1] != cShape[1])
2094  return emitError("operand shapes do not satisfy matmul constraints");
2095 
2096  return success();
2097 }
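// Worked example of the matmul constraint (sketch): with
//   A : !gpu.mma_matrix<16x16xf16, "AOp">
//   B : !gpu.mma_matrix<16x16xf16, "BOp">
//   C : !gpu.mma_matrix<16x16xf16, "COp">
// aShape[1] == bShape[0], aShape[0] == cShape[0], and bShape[1] == cShape[1]
// all hold (every dimension is 16), so verification succeeds.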
2098 
2099 LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
2100  SmallVectorImpl<::mlir::OpFoldResult> &results) {
2101  return memref::foldMemRefCast(*this);
2102 }
2103 
2104 LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
2105  SmallVectorImpl<::mlir::OpFoldResult> &results) {
2106  return memref::foldMemRefCast(*this);
2107 }
2108 
2109 //===----------------------------------------------------------------------===//
2110 // GPU_WaitOp
2111 //===----------------------------------------------------------------------===//
2112 
2113 namespace {
2114 
2115 /// Remove, from a gpu.wait op, async-dependency tokens defined by gpu.wait ops that themselves have no async dependencies.
2116 /// %t = gpu.wait async [] // No async dependencies.
2117 /// ... gpu.wait ... [%t, ...] // %t can be removed.
2118 struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
2119 public:
2120  using OpRewritePattern::OpRewritePattern;
2121 
2122  LogicalResult matchAndRewrite(WaitOp op,
2123  PatternRewriter &rewriter) const final {
2124  auto predicate = [](Value value) {
2125  auto waitOp = value.getDefiningOp<WaitOp>();
2126  return waitOp && waitOp->getNumOperands() == 0;
2127  };
2128  if (llvm::none_of(op.getAsyncDependencies(), predicate))
2129  return failure();
2130  SmallVector<Value> validOperands;
2131  for (Value operand : op->getOperands()) {
2132  if (predicate(operand))
2133  continue;
2134  validOperands.push_back(operand);
2135  }
2136  rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2137  return success();
2138  }
2139 };
2140 
2141 /// Simplify trivial gpu.wait ops for the following patterns.
2142 /// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
2143 /// dependencies).
2144 /// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with
2145 /// %t0.
2146 /// 3. gpu.wait [] ops, i.e. gpu.wait ops that neither have any async
2147 /// dependencies nor return any token.
2148 struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2149 public:
2150  using OpRewritePattern::OpRewritePattern;
2151 
2152  LogicalResult matchAndRewrite(WaitOp op,
2153  PatternRewriter &rewriter) const final {
2154  // Erase gpu.wait ops that neither have any async dependencies nor return
2155  // any async token.
2156  if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2157  rewriter.eraseOp(op);
2158  return success();
2159  }
2160  // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2161  if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2162  op.getAsyncToken()) {
2163  rewriter.replaceOp(op, op.getAsyncDependencies());
2164  return success();
2165  }
2166  // Erase %t = gpu.wait async ... ops, where %t has no uses.
2167  if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2168  rewriter.eraseOp(op);
2169  return success();
2170  }
2171  return failure();
2172  }
2173 };
2174 
2175 } // end anonymous namespace
2176 
2177 void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2178  MLIRContext *context) {
2179  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2180 }
2181 
2182 //===----------------------------------------------------------------------===//
2183 // GPU_AllocOp
2184 //===----------------------------------------------------------------------===//
2185 
2186 LogicalResult AllocOp::verify() {
2187  auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2188 
2189  if (getDynamicSizes().size() != memRefType.getNumDynamicDims())
2190  return emitOpError("dimension operand count does not equal memref "
2191  "dynamic dimension count");
2192 
2193  unsigned numSymbols = 0;
2194  if (!memRefType.getLayout().isIdentity())
2195  numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2196  if (getSymbolOperands().size() != numSymbols) {
2197  return emitOpError(
2198  "symbol operand count does not equal memref symbol count");
2199  }
2200 
2201  return success();
2202 }
2203 
2204 namespace {
2205 
2206 /// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
2207 /// `memref::AllocOp`.
2208 struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2209  using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2210 
2211  LogicalResult matchAndRewrite(memref::DimOp dimOp,
2212  PatternRewriter &rewriter) const override {
2213  std::optional<int64_t> index = dimOp.getConstantIndex();
2214  if (!index)
2215  return failure();
2216 
2217  auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2218  if (!memrefType || index.value() >= memrefType.getRank() ||
2219  !memrefType.isDynamicDim(index.value()))
2220  return failure();
2221 
2222  auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2223  if (!alloc)
2224  return failure();
2225 
2226  Value substituteOp = *(alloc.getDynamicSizes().begin() +
2227  memrefType.getDynamicDimIndex(index.value()));
2228  rewriter.replaceOp(dimOp, substituteOp);
2229  return success();
2230  }
2231 };
2232 
2233 } // namespace
2234 
2235 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2236  MLIRContext *context) {
2237  results.add<SimplifyDimOfAllocOp>(context);
2238 }
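// IR sketch of the folding (illustrative):
//
//   %m = gpu.alloc (%size) : memref<?xf32>
//   %c0 = arith.constant 0 : index
//   %d = memref.dim %m, %c0 : memref<?xf32>
//
// Uses of %d are replaced by %size.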
2239 
2240 //===----------------------------------------------------------------------===//
2241 // GPU object attribute
2242 //===----------------------------------------------------------------------===//
2243 
2244 LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2245  Attribute target, CompilationTarget format,
2246  StringAttr object, DictionaryAttr properties,
2247  KernelTableAttr kernels) {
2248  if (!target)
2249  return emitError() << "the target attribute cannot be null";
2250  if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2251  return success();
2252  return emitError() << "the target attribute must implement or promise the "
2253  "`gpu::TargetAttrInterface`";
2254 }
2255 
2256 namespace {
2257 ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2258  StringAttr &object) {
2259  std::optional<CompilationTarget> formatResult;
2260  StringRef enumKeyword;
2261  auto loc = odsParser.getCurrentLocation();
2262  if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
2263  formatResult = CompilationTarget::Fatbin;
2264  if (!formatResult &&
2265  (formatResult =
2266  gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2267  odsParser.parseEqual())
2268  return odsParser.emitError(loc, "expected an equal sign");
2269  if (!formatResult)
2270  return odsParser.emitError(loc, "expected keyword for GPU object format");
2271  FailureOr<StringAttr> objectResult =
2272  FieldParser<StringAttr>::parse(odsParser);
2273  if (failed(objectResult))
2274  return odsParser.emitError(odsParser.getCurrentLocation(),
2275  "failed to parse GPU_ObjectAttr parameter "
2276  "'object' which is to be a `StringAttr`");
2277  format = *formatResult;
2278  object = *objectResult;
2279  return success();
2280 }
2281 
2282 void printObject(AsmPrinter &odsParser, CompilationTarget format,
2283  StringAttr object) {
2284  if (format != CompilationTarget::Fatbin)
2285  odsParser << stringifyEnum(format) << " = ";
2286  odsParser << object;
2287 }
2288 } // namespace
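// Illustrative attribute forms accepted above (sketch; keywords follow the
// CompilationTarget enum, and the format keyword is omitted when it is the
// default fatbin):
//
//   #gpu.object<#nvvm.target, "...">
//   #gpu.object<#nvvm.target, assembly = "...">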
2289 
2290 //===----------------------------------------------------------------------===//
2291 // GPU select object attribute
2292 //===----------------------------------------------------------------------===//
2293 
2294 LogicalResult
2295 gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2296  Attribute target) {
2297  // Check `target`: it can be null, an integer attribute, or a GPU Target attribute.
2298  if (target) {
2299  if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2300  if (intAttr.getInt() < 0) {
2301  return emitError() << "the object index must be positive";
2302  }
2303  } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2304  return emitError()
2305  << "the target attribute must be a GPU Target attribute";
2306  }
2307  }
2308  return success();
2309 }
2310 
2311 //===----------------------------------------------------------------------===//
2312 // DynamicSharedMemoryOp
2313 //===----------------------------------------------------------------------===//
2314 
2315 LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2316  if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2317  return emitOpError() << "must be inside an op with symbol table";
2318 
2319  MemRefType memrefType = getResultMemref().getType();
2320  // Check address space
2321  if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2322  return emitOpError() << "address space must be "
2323  << gpu::AddressSpaceAttr::getMnemonic() << "<"
2324  << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2325  }
2326  if (memrefType.hasStaticShape()) {
2327  return emitOpError() << "result memref type must be memref<?xi8, "
2328  "#gpu.address_space<workgroup>>";
2329  }
2330  return success();
2331 }
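// Per the checks above, the only well-formed result type is a dynamically
// shaped i8 memref in workgroup memory, e.g.:
//
//   %shmem = gpu.dynamic_shared_memory
//       : memref<?xi8, #gpu.address_space<workgroup>>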
2332 
2333 //===----------------------------------------------------------------------===//
2334 // GPU WarpExecuteOnLane0Op
2335 //===----------------------------------------------------------------------===//
2336 
2337 void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
2338  p << "(" << getLaneid() << ")";
2339 
2340  SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
2341  auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
2342  p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";
2343 
2344  if (!getArgs().empty())
2345  p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
2346  if (!getResults().empty())
2347  p << " -> (" << getResults().getTypes() << ')';
2348  p << " ";
2349  p.printRegion(getRegion(),
2350  /*printEntryBlockArgs=*/true,
2351  /*printBlockTerminators=*/!getResults().empty());
2352  p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
2353 }
2354 
2355 ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
2356  OperationState &result) {
2357  // Create the region.
2358  result.regions.reserve(1);
2359  Region *warpRegion = result.addRegion();
2360 
2361  auto &builder = parser.getBuilder();
2362  OpAsmParser::UnresolvedOperand laneId;
2363 
2364  // Parse predicate operand.
2365  if (parser.parseLParen() ||
2366  parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
2367  parser.parseRParen())
2368  return failure();
2369 
2370  int64_t warpSize;
2371  if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
2372  parser.parseRSquare())
2373  return failure();
2374  result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
2375  builder.getContext())),
2376  builder.getI64IntegerAttr(warpSize));
2377 
2378  if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
2379  return failure();
2380 
2381  llvm::SMLoc inputsOperandsLoc;
2382  SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
2383  SmallVector<Type> inputTypes;
2384  if (succeeded(parser.parseOptionalKeyword("args"))) {
2385  if (parser.parseLParen())
2386  return failure();
2387 
2388  inputsOperandsLoc = parser.getCurrentLocation();
2389  if (parser.parseOperandList(inputsOperands) ||
2390  parser.parseColonTypeList(inputTypes) || parser.parseRParen())
2391  return failure();
2392  }
2393  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
2394  result.operands))
2395  return failure();
2396 
2397  // Parse optional results type list.
2398  if (parser.parseOptionalArrowTypeList(result.types))
2399  return failure();
2400  // Parse the region.
2401  if (parser.parseRegion(*warpRegion, /*arguments=*/{},
2402  /*argTypes=*/{}))
2403  return failure();
2404  WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
2405 
2406  // Parse the optional attribute list.
2407  if (parser.parseOptionalAttrDict(result.attributes))
2408  return failure();
2409  return success();
2410 }
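// Illustrative textual form matching this parser (sketch): with a warp size
// of 32, an arg distributed as vector<4xf32> expands to vector<128xf32>
// inside the region, and a yielded vector<32xf32> distributes to a
// vector<1xf32> result.
//
//   %r = gpu.warp_execute_on_lane_0(%laneid)[32]
//       args(%v : vector<4xf32>) -> (vector<1xf32>) {
//   ^bb0(%arg0: vector<128xf32>):
//     ...
//     gpu.yield %y : vector<32xf32>
//   }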
2411 
2412 void WarpExecuteOnLane0Op::getSuccessorRegions(
2413  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2414  if (!point.isParent()) {
2415  regions.push_back(RegionSuccessor(getResults()));
2416  return;
2417  }
2418 
2419  // The warp region is always executed
2420  regions.push_back(RegionSuccessor(&getWarpRegion()));
2421 }
2422 
2423 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2424  TypeRange resultTypes, Value laneId,
2425  int64_t warpSize) {
2426  build(builder, result, resultTypes, laneId, warpSize,
2427  /*operands=*/{}, /*argTypes=*/{});
2428 }
2429 
2430 void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2431  TypeRange resultTypes, Value laneId,
2432  int64_t warpSize, ValueRange args,
2433  TypeRange blockArgTypes) {
2434  result.addOperands(laneId);
2435  result.addAttribute(getAttributeNames()[0],
2436  builder.getI64IntegerAttr(warpSize));
2437  result.addTypes(resultTypes);
2438  result.addOperands(args);
2439  assert(args.size() == blockArgTypes.size());
2440  OpBuilder::InsertionGuard guard(builder);
2441  Region *warpRegion = result.addRegion();
2442  Block *block = builder.createBlock(warpRegion);
2443  for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
2444  block->addArgument(type, arg.getLoc());
2445 }
2446 
2447 /// Helper to check that the distributed vector type is consistent with the
2448 /// expanded type and the distribution (warp) size.
2449 static LogicalResult verifyDistributedType(Type expanded, Type distributed,
2450  int64_t warpSize, Operation *op) {
2451  // If the types match, there is no distribution.
2452  if (expanded == distributed)
2453  return success();
2454  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
2455  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
2456  if (!expandedVecType || !distributedVecType)
2457  return op->emitOpError("expected vector type for distributed operands.");
2458  if (expandedVecType.getRank() != distributedVecType.getRank() ||
2459  expandedVecType.getElementType() != distributedVecType.getElementType())
2460  return op->emitOpError(
2461  "expected distributed vectors to have same rank and element type.");
2462 
2463  SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
2464  for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
2465  int64_t eDim = expandedVecType.getDimSize(i);
2466  int64_t dDim = distributedVecType.getDimSize(i);
2467  if (eDim == dDim)
2468  continue;
2469  if (eDim % dDim != 0)
2470  return op->emitOpError()
2471  << "expected expanded vector dimension #" << i << " (" << eDim
2472  << ") to be a multipler of the distributed vector dimension ("
2473  << dDim << ")";
2474  scales[i] = eDim / dDim;
2475  }
2476  if (std::accumulate(scales.begin(), scales.end(), 1,
2477  std::multiplies<int64_t>()) != warpSize)
2478  return op->emitOpError()
2479  << "incompatible distribution dimensions from " << expandedVecType
2480  << " to " << distributedVecType << " with warp size = " << warpSize;
2481 
2482  return success();
2483 }
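// Worked example (sketch): distributing vector<2x64xf32> to vector<2x2xf32>
// with warpSize = 32 verifies: dim #0 is unchanged (scale 1), dim #1 gives
// 64 / 2 = 32, and the product of the scales, 1 * 32, equals the warp size.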
2484 
2485 LogicalResult WarpExecuteOnLane0Op::verify() {
2486  if (getArgs().size() != getWarpRegion().getNumArguments())
2487  return emitOpError(
2488  "expected same number op arguments and block arguments.");
2489  gpu::YieldOp yield = getTerminator();
2490  if (yield.getNumOperands() != getNumResults())
2491  return emitOpError(
2492  "expected same number of yield operands and return values.");
2493  int64_t warpSize = getWarpSize();
2494  for (auto [regionArg, arg] :
2495  llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
2496  if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
2497  warpSize, getOperation())))
2498  return failure();
2499  }
2500  for (auto [yieldOperand, result] :
2501  llvm::zip_equal(yield.getOperands(), getResults())) {
2502  if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
2503  warpSize, getOperation())))
2504  return failure();
2505  }
2506  return success();
2507 }
2508 bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
2509  return succeeded(
2510  verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
2511 }
2512 
2513 gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
2514  return cast<gpu::YieldOp>(getBody()->getTerminator());
2515 }
2516 
2517 //===----------------------------------------------------------------------===//
2518 // GPU_SubgroupBroadcastOp
2519 //===----------------------------------------------------------------------===//
2520 
2521 void gpu::SubgroupBroadcastOp::inferResultRanges(
2522  ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
2523  setResultRange(getResult(), argRanges.front());
2524 }
2525 
2526 Speculation::Speculatability gpu::SubgroupBroadcastOp::getSpeculatability() {
2527  switch (getBroadcastType()) {
2528  case BroadcastType::first_active_lane:
2529  // Cannot speculate a first_active_lane broadcast, because speculating it
2530  // across control flow can change the set of active lanes.
2531  return Speculation::NotSpeculatable;
2532  case BroadcastType::specific_lane:
2533  // Speculation should be safe as long as we are inside structured control flow.
2534  return Speculation::Speculatable;
2535  }
2536 }
2537 
2538 LogicalResult gpu::SubgroupBroadcastOp::verify() {
2539  switch (getBroadcastType()) {
2540  case BroadcastType::first_active_lane:
2541  if (getLane())
2542  return emitOpError()
2543  << "lane can only be specified for `specific_lane` broadcast";
2544  return success();
2545  case BroadcastType::specific_lane:
2546  if (!getLane())
2547  return emitOpError()
2548  << "lane must be specified for `specific_lane` broadcast";
2549  return success();
2550  }
2551 }
2552 
2553 //===----------------------------------------------------------------------===//
2554 // GPU KernelMetadataAttr
2555 //===----------------------------------------------------------------------===//
2556 
2557 KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
2558  DictionaryAttr metadata) {
2559  assert(kernel && "invalid kernel");
2560  return get(kernel.getNameAttr(), kernel.getFunctionType(),
2561  kernel.getAllArgAttrs(), metadata);
2562 }
2563 
2564 KernelMetadataAttr
2565 KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
2566  FunctionOpInterface kernel,
2567  DictionaryAttr metadata) {
2568  assert(kernel && "invalid kernel");
2569  return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
2570  kernel.getAllArgAttrs(), metadata);
2571 }
2572 
2573 KernelMetadataAttr
2574 KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
2575  if (attrs.empty())
2576  return *this;
2577  NamedAttrList attrList;
2578  if (DictionaryAttr dict = getMetadata())
2579  attrList.append(dict);
2580  attrList.append(attrs);
2581  return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
2582  attrList.getDictionary(getContext()));
2583 }
2584 
2585 LogicalResult
2586 KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2587  StringAttr name, Type functionType,
2588  ArrayAttr argAttrs, DictionaryAttr metadata) {
2589  if (name.empty())
2590  return emitError() << "the kernel name can't be empty";
2591  if (argAttrs) {
2592  if (llvm::any_of(argAttrs, [](Attribute attr) {
2593  return !llvm::isa<DictionaryAttr>(attr);
2594  }))
2595  return emitError()
2596  << "all attributes in the array must be a dictionary attribute";
2597  }
2598  return success();
2599 }
2600 
2601 //===----------------------------------------------------------------------===//
2602 // GPU KernelTableAttr
2603 //===----------------------------------------------------------------------===//
2604 
2605 KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2606  ArrayRef<KernelMetadataAttr> kernels,
2607  bool isSorted) {
2608  // Note that `is_sorted` is invoked at most once here, even with assertions enabled.
2609  assert((!isSorted || llvm::is_sorted(kernels)) &&
2610  "expected a sorted kernel array");
2611  // Immediately return the attribute if the array is sorted.
2612  if (isSorted || llvm::is_sorted(kernels))
2613  return Base::get(context, kernels);
2614  // Sort the array.
2615  SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2616  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2617  return Base::get(context, kernelsTmp);
2618 }
2619 
2620 KernelTableAttr KernelTableAttr::getChecked(
2621  function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
2622  ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
2623  // Note that `is_sorted` is invoked at most once here, even with assertions enabled.
2624  assert((!isSorted || llvm::is_sorted(kernels)) &&
2625  "expected a sorted kernel array");
2626  // Immediately return the attribute if the array is sorted.
2627  if (isSorted || llvm::is_sorted(kernels))
2628  return Base::getChecked(emitError, context, kernels);
2629  // Sort the array.
2630  SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2631  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2632  return Base::getChecked(emitError, context, kernelsTmp);
2633 }
2634 
2635 LogicalResult
2636 KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2637  ArrayRef<KernelMetadataAttr> kernels) {
2638  if (kernels.size() < 2)
2639  return success();
2640  // Check that the kernels are uniquely named.
2641  if (std::adjacent_find(kernels.begin(), kernels.end(),
2642  [](KernelMetadataAttr l, KernelMetadataAttr r) {
2643  return l.getName() == r.getName();
2644  }) != kernels.end()) {
2645  return emitError() << "expected all kernels to be uniquely named";
2646  }
2647  return success();
2648 }
2649 
2650 KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
2651  auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2652  return found ? *iterator : KernelMetadataAttr();
2653 }
2654 
2655 KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
2656  auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2657  return found ? *iterator : KernelMetadataAttr();
2658 }
2659 
2660 //===----------------------------------------------------------------------===//
2661 // GPU target options
2662 //===----------------------------------------------------------------------===//
2663 
2664 TargetOptions::TargetOptions(
2665  StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
2666  StringRef cmdOptions, StringRef elfSection,
2667  CompilationTarget compilationTarget,
2668  function_ref<SymbolTable *()> getSymbolTableCallback,
2669  function_ref<void(llvm::Module &)> initialLlvmIRCallback,
2670  function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2671  function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2672  function_ref<void(StringRef)> isaCallback)
2673  : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
2674  cmdOptions, elfSection, compilationTarget,
2675  getSymbolTableCallback, initialLlvmIRCallback,
2676  linkedLlvmIRCallback, optimizedLlvmIRCallback,
2677  isaCallback) {}
2678 
2679 TargetOptions::TargetOptions(
2680  TypeID typeID, StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
2681  StringRef cmdOptions, StringRef elfSection,
2682  CompilationTarget compilationTarget,
2683  function_ref<SymbolTable *()> getSymbolTableCallback,
2684  function_ref<void(llvm::Module &)> initialLlvmIRCallback,
2685  function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2686  function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2687  function_ref<void(StringRef)> isaCallback)
2688  : toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
2689  cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
2690  compilationTarget(compilationTarget),
2691  getSymbolTableCallback(getSymbolTableCallback),
2692  initialLlvmIRCallback(initialLlvmIRCallback),
2693  linkedLlvmIRCallback(linkedLlvmIRCallback),
2694  optimizedLlvmIRCallback(optimizedLlvmIRCallback),
2695  isaCallback(isaCallback), typeID(typeID) {}
2696 
2697 TypeID TargetOptions::getTypeID() const { return typeID; }
2698 
2699 StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2700 
2701 ArrayRef<Attribute> TargetOptions::getLibrariesToLink() const {
2702  return librariesToLink;
2703 }
2704 
2705 StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2706 
2707 StringRef TargetOptions::getELFSection() const { return elfSection; }
2708 
2709 SymbolTable *TargetOptions::getSymbolTable() const {
2710  return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
2711 }
2712 
2713 function_ref<void(llvm::Module &)>
2714 TargetOptions::getInitialLlvmIRCallback() const {
2715  return initialLlvmIRCallback;
2716 }
2717 
2718 function_ref<void(llvm::Module &)>
2719 TargetOptions::getLinkedLlvmIRCallback() const {
2720  return linkedLlvmIRCallback;
2721 }
2722 
2723 function_ref<void(llvm::Module &)>
2724 TargetOptions::getOptimizedLlvmIRCallback() const {
2725  return optimizedLlvmIRCallback;
2726 }
2727 
2728 function_ref<void(StringRef)> TargetOptions::getISACallback() const {
2729  return isaCallback;
2730 }
2731 
2732 CompilationTarget TargetOptions::getCompilationTarget() const {
2733  return compilationTarget;
2734 }
2735 
2736 CompilationTarget TargetOptions::getDefaultCompilationTarget() {
2737  return CompilationTarget::Fatbin;
2738 }
2739 
2740 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2741 TargetOptions::tokenizeCmdOptions(const std::string &cmdOptions) {
2742  std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2743  llvm::StringSaver stringSaver(options.first);
2744  StringRef opts = cmdOptions;
2745  // For a correct tokenization of the command line options, `opts` must be
2746  // unquoted; otherwise the tokenization function returns the whole quoted
2747  // `cmdOptions` as a single string, which is not the desired behavior.
2748  // Remove any quotes if they are at the beginning and end of the string:
2749  if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2750  opts.consume_front("\""), opts.consume_back("\"");
2751  if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
2752  opts.consume_front("'"), opts.consume_back("'");
2753 #ifdef _WIN32
2754  llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
2755  /*MarkEOLs=*/false);
2756 #else
2757  llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
2758  /*MarkEOLs=*/false);
2759 #endif // _WIN32
2760  return options;
2761 }
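// For example (sketch): tokenizeCmdOptions("\"-O3 --verbose\"") strips the
// surrounding quotes first and yields the tokens {"-O3", "--verbose"} rather
// than the single token "-O3 --verbose".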
2762 
2763 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2764 TargetOptions::tokenizeCmdOptions() const {
2765  return tokenizeCmdOptions(cmdOptions);
2766 }
2767 
2768 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2769 TargetOptions::tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith) {
2770  size_t startPos = cmdOptions.find(startsWith);
2771  if (startPos == std::string::npos)
2772  return {llvm::BumpPtrAllocator(), SmallVector<const char *>()};
2773 
2774  auto tokenized =
2775  tokenizeCmdOptions(cmdOptions.substr(startPos + startsWith.size()));
2776  cmdOptions.resize(startPos);
2777  return tokenized;
2778 }
2779 
2779 
2780 MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::gpu::TargetOptions)
2781 
2782 #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2783 #include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2784 
2785 #define GET_ATTRDEF_CLASSES
2786 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2787 
2788 #define GET_OP_CLASSES
2789 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2790 
2791 #include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, OperandRange operands, TypeRange types)
static ParseResult parseAsyncDependencies(OpAsmParser &parser, Type &asyncTokenType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &asyncDependencies)
Parses an optional list of async operands with an optional leading keyword.
Definition: GPUDialect.cpp:491
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr)
Definition: GPUDialect.cpp:654
static void setAttributionAttrs(GPUFuncOp op, unsigned index, DictionaryAttr value, StringAttr attrName)
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, Type asyncTokenType, OperandRange asyncDependencies)
Prints optional async dependencies with its leading keyword.
Definition: GPUDialect.cpp:507
static ParseResult parseOffloadingHandler(OpAsmParser &parser, Attribute &offloadingHandler)
static ParseResult parseSizeAssignment(OpAsmParser &parser, MutableArrayRef< OpAsmParser::UnresolvedOperand > sizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > regionSizes, MutableArrayRef< OpAsmParser::UnresolvedOperand > indices)
Definition: GPUDialect.cpp:988
static void printAttributions(OpAsmPrinter &p, StringRef keyword, ArrayRef< BlockArgument > values)
Prints a GPU function memory attribution.
Definition: GPUDialect.cpp:539
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, StringAttr attrName)
static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy, Value clusterValue, Type clusterXTy, Type clusterYTy, Type clusterZTy)
static bool canMakeGroupOpUniform(Operation *op)
Definition: GPUDialect.cpp:632
static std::string getSparseHandleKeyword(SparseHandleKind kind)
Definition: GPUDialect.cpp:291
static LogicalResult verifyKnownLaunchSizeAttr(Operation *op, NamedAttribute attr)
Definition: GPUDialect.cpp:381
static void printAllReduceOperation(AsmPrinter &printer, Operation *op, AllReduceOperationAttr attr)
Definition: GPUDialect.cpp:667
static ParseResult parseAttributions(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::Argument > &args)
Parses a GPU function memory attribution.
Definition: GPUDialect.cpp:528
static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy, std::optional< OpAsmParser::UnresolvedOperand > clusterValue, Type &clusterXTy, Type &clusterYTy, Type &clusterZTy)
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, Attribute value, StringAttr attrsName)
static ParseResult parseLaunchFuncOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &argNames, SmallVectorImpl< Type > &argTypes)
static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, Attribute offloadingHandler)
static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName, Type resType)
Definition: GPUDialect.cpp:578
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size, KernelDim3 operands, KernelDim3 ids)
Definition: GPUDialect.cpp:923
static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, StringAttr attrsName)
static LogicalResult verifyAttributions(Operation *op, ArrayRef< BlockArgument > attributions, gpu::AddressSpace memorySpace)
Verifies a GPU function memory attribution.
Definition: GPUDialect.cpp:552
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
union mlir::linalg::@1243::ArityGroupAndKind::Kind kind
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
#define MINUI(lhs, rhs)
static sycl::kernel * getKernel(ze_module_handle_t zeModule, const char *name)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:323
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalString(std::string *string)=0
Parse a quoted string token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:153
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
UnitAttr getUnitAttr()
Definition: Builders.cpp:97
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:199
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:162
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:75
IntegerType getI32Type()
Definition: Builders.cpp:62
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:111
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:91
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:261
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:265
IndexType getIndexType()
Definition: Builders.cpp:50
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition: Builders.cpp:103
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition: Builders.cpp:93
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:98
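As a quick illustration of these Builder entry points (the helper name and attribute keys below are made up for the example):

#include "mlir/IR/Builders.h"
using namespace mlir;

// Assemble a small attribute dictionary using the getters listed above.
DictionaryAttr makeExampleAttrs(MLIRContext *ctx) {
  Builder b(ctx);
  NamedAttribute attrs[] = {
      b.getNamedAttr("flag", b.getUnitAttr()),
      b.getNamedAttr("count", b.getI64IntegerAttr(42)),
      b.getNamedAttr("sizes", b.getDenseI32ArrayAttr({1, 2, 3}))};
  return b.getDictionaryAttr(attrs);
}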
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and types.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom printer.
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it, emitting errors, etc.
virtual size_t getNumResults() const =0
Return the number of declared SSA results.
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to the list on success.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter, and an optional required operand count.
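A minimal sketch of how these hooks combine in a custom op parser, assuming a hypothetical op written as my.add %a, %b : i32 (the op and function names are illustrative):

static ParseResult parseMyAddOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> operands;
  Type type;
  if (parser.parseOperandList(operands) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type) ||
      parser.resolveOperands(operands, type, result.operands))
    return failure();
  result.addTypes(type); // single result, same type as the operands
  return success();
}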
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom printer.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
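And a matching sketch for the printer side of the same hypothetical op (printOperands is another OpAsmPrinter hook, not listed above):

static void printMyAddOp(OpAsmPrinter &p, Operation *op) {
  p << ' ';
  p.printOperands(op->getOperands());
  p.printOptionalAttrDict(op->getAttrs());
  p << " : " << op->getResult(0).getType();
}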
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:429
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
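A sketch of the usual pattern built from these OpBuilder facilities: populate a fresh block inside a region without disturbing the caller's insertion point (the helper name is illustrative):

static void buildEntryBlock(OpBuilder &builder, Region &region) {
  OpBuilder::InsertionGuard guard(builder); // restores the point on scope exit
  Block *entry = builder.createBlock(&region);
  builder.setInsertionPointToStart(entry);
  // ... ops created here land at the start of `entry` ...
}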
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
Definition: Operation.cpp:255
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:550
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:267
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:582
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
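These Operation hooks are the raw material of verifiers; a minimal sketch (the check and attribute name are illustrative):

static LogicalResult checkHasUnitAttr(Operation *op, StringAttr name) {
  if (op->getAttrOfType<UnitAttr>(name))
    return success();
  return op->emitOpError("expected unit attribute '")
         << name.getValue() << "'";
}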
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched.
Definition: PatternMatch.h:793
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
Block & emplaceBlock()
Definition: Region.h:46
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:855
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:646
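A sketch of how such patterns are defined and registered with a RewritePatternSet; the skeleton mirrors the launch-simplification pattern documented further below, but its name and body are placeholders, not the actual GPU-dialect logic:

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;

struct SimplifyLaunchSketch : OpRewritePattern<gpu::LaunchOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(gpu::LaunchOp op,
                                PatternRewriter &rewriter) const override {
    // Inspect `op`, then mutate only through `rewriter`
    // (replaceAllUsesWith, eraseOp, ...); return failure() when
    // the pattern does not apply.
    return failure();
  }
};

void populateExamplePatterns(RewritePatternSet &patterns) {
  patterns.add<SimplifyLaunchSketch>(patterns.getContext());
}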
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:76
bool isIndex() const
Definition: Types.cpp:54
bool isF32() const
Definition: Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:88
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
bool isF16() const
Definition: Types.cpp:38
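A sketch of the kind of element-type screening these predicates enable (compare isValidElementType on MMAMatrixType below); the helper is illustrative:

// Accept f16/f32 and any integer type; isInteger() excludes index.
static bool isExampleElementType(Type type) {
  return type.isF16() || type.isF32() || type.isInteger();
}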
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
A utility result that is used to signal how to proceed with an ongoing walk: interrupt, advance, or skip.
Definition: WalkResult.h:29
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
static ConcreteT get(MLIRContext *ctx, Args &&...args)
Get or create a new ConcreteT instance within the ctx.
ImplType * getImpl() const
Utility for easy access to the storage instance.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition: GPUDialect.h:131
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
Definition: GPUDialect.cpp:202
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction invariants.
Definition: GPUDialect.cpp:187
Type getElementType() const
Get elementType of a single element.
Definition: GPUDialect.cpp:206
static bool isValidElementType(Type elementType)
Check if a type is a valid MMAMatrixType elementType.
Definition: GPUDialect.cpp:210
static LogicalResult verifyInvariants(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Verify that shape and elementType are actually allowed for the MMAMatrixType.
Definition: GPUDialect.cpp:217
StringRef getOperand() const
The general form of operation this type supports is given by the equation C += A*B; this returns which operand of that equation ("AOp", "BOp", or "COp") the type holds.
Definition: GPUDialect.cpp:208
static MMAMatrixType getChecked(function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType at a particular location and verify construction invariants.
Definition: GPUDialect.cpp:193
unsigned getNumDims() const
Get number of dims.
Definition: GPUDialect.cpp:200
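For instance, the A-operand fragment type !gpu.mma_matrix<16x16xf16, "AOp"> can be built programmatically; a minimal sketch (the helper name is illustrative):

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
using namespace mlir;

gpu::MMAMatrixType makeAOperandType(MLIRContext *ctx) {
  Type f16 = Float16Type::get(ctx);
  return gpu::MMAMatrixType::get(/*shape=*/{16, 16}, f16, /*operand=*/"AOp");
}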
This class serves as an opaque interface for passing options to the TargetAttrInterface methods.
std::string cmdOptions
An optional set of command line options to be used by the compilation process.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeCmdOptions() const
Returns a tokenization of the command line options.
std::pair< llvm::BumpPtrAllocator, SmallVector< const char * > > tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith)
Returns a tokenization of the substring of the command line options that starts with startsWith and extends to the end of the options, removing that substring from cmdOptions.
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments, to the list of operation attributes in result.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::unique_function< InFlightDiagnostic()> getDefaultDiagnosticEmitFn(MLIRContext *ctx)
Utility method to generate a callback that can be used to generate a diagnostic when checking the construction invariants of a storage object.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function attributes prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
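A hedged sketch of how a function-like op's parser typically strings these helpers together; details vary by op, and the namespace qualification of these free functions (an interface-impl namespace) has moved across MLIR versions, so it is omitted here:

static ParseResult parseExampleFuncSignature(OpAsmParser &parser,
                                             OperationState &result) {
  SmallVector<OpAsmParser::Argument> args;
  SmallVector<Type> resultTypes;
  SmallVector<DictionaryAttr> resultAttrs;
  bool isVariadic = false;
  if (parseFunctionSignatureWithArguments(parser, /*allowVariadic=*/false,
                                          args, isVariadic, resultTypes,
                                          resultAttrs))
    return failure();
  // A real parser would now build the FunctionType, parse the attribute
  // dictionary and body region, and attach the argument/result attributes
  // via addArgAndResultAttrs.
  return success();
}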
void addAsyncDependency(Operation *op, Value token)
Definition: GPUDialect.cpp:731
llvm::StringMap< llvm::SmallString< 8 > > dictionary
A dictionary stores a mapping of template variable names to their assigned string values.
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:478
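This matcher pairing is exactly what drives simplifications like the launch-dimension fold documented below; a quick sketch (the helper name is illustrative):

#include "mlir/IR/Matchers.h"
using namespace mlir;

// True when `v` is a constant integer one (scalar or splat).
static bool isConstantOne(Value v) {
  return matchPattern(v, m_One());
}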
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type in the given MLIR context, if it is valid.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
Simplify the gpu.launch when the range of a thread or block ID is trivially known to be one.
LogicalResult matchAndRewrite(LaunchOp op, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:314
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) the properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
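A sketch of generic op construction through OperationState; the op name and attribute are illustrative, and creating an unregistered op like this requires the context to allow unregistered dialects:

Operation *buildGenericOp(OpBuilder &b, Location loc, ValueRange args,
                          Type resultType) {
  OperationState state(loc, "mydialect.example");
  state.addOperands(args);
  state.addTypes(resultType);
  state.addAttribute("tag", b.getStringAttr("example"));
  return b.create(state); // creates and inserts at b's insertion point
}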
Utility class for the GPU dialect to represent triples of Values accessible through .x, .y, and .z, similarly to CUDA notation.
Definition: GPUDialect.h:39