MLIR  14.0.0git
OpenACC.cpp
Go to the documentation of this file.
1 //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // =============================================================================
8 
11 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/BuiltinTypes.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 
20 using namespace mlir;
21 using namespace acc;
22 
23 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
24 
25 //===----------------------------------------------------------------------===//
26 // OpenACC operations
27 //===----------------------------------------------------------------------===//
28 
/// Register the OpenACC dialect's operations and attributes. Both type lists
/// are produced by including the ODS-generated .cpp.inc files with the
/// corresponding GET_*_LIST macro defined.
29 void OpenACCDialect::initialize() {
  // Operations listed by OpenACCOps.cpp.inc under GET_OP_LIST.
30  addOperations<
31 #define GET_OP_LIST
32 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
33  >();
  // Attributes listed by OpenACCOpsAttributes.cpp.inc under GET_ATTRDEF_LIST.
34  addAttributes<
35 #define GET_ATTRDEF_LIST
36 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
37  >();
38 }
39 
40 template <typename StructureOp>
42  unsigned nRegions = 1) {
43 
45  for (unsigned i = 0; i < nRegions; ++i)
46  regions.push_back(state.addRegion());
47 
48  for (Region *region : regions) {
49  if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
50  return failure();
51  }
52 
53  return success();
54 }
55 
56 static ParseResult
57 parseOperandList(OpAsmParser &parser, StringRef keyword,
59  SmallVectorImpl<Type> &argTypes, OperationState &result) {
60  if (failed(parser.parseOptionalKeyword(keyword)))
61  return success();
62 
63  if (failed(parser.parseLParen()))
64  return failure();
65 
66  // Exit early if the list is empty.
67  if (succeeded(parser.parseOptionalRParen()))
68  return success();
69 
70  do {
72  Type type;
73 
74  if (parser.parseRegionArgument(arg) || parser.parseColonType(type))
75  return failure();
76 
77  args.push_back(arg);
78  argTypes.push_back(type);
79  } while (succeeded(parser.parseOptionalComma()));
80 
81  if (failed(parser.parseRParen()))
82  return failure();
83 
84  return parser.resolveOperands(args, argTypes, parser.getCurrentLocation(),
85  result.operands);
86 }
87 
89  StringRef listName, OpAsmPrinter &printer) {
90 
91  if (!operands.empty()) {
92  printer << " " << listName << "(";
93  llvm::interleaveComma(operands, printer, [&](Value op) {
94  printer << op << ": " << op.getType();
95  });
96  printer << ")";
97  }
98 }
99 
101  OpAsmParser::OperandType &operand,
102  Type type, bool &hasOptional,
103  OperationState &result) {
104  hasOptional = false;
105  if (succeeded(parser.parseOptionalKeyword(keyword))) {
106  hasOptional = true;
107  if (parser.parseLParen() || parser.parseOperand(operand) ||
108  parser.resolveOperand(operand, type, result.operands) ||
109  parser.parseRParen())
110  return failure();
111  }
112  return success();
113 }
114 
116  OperationState &result) {
117  OpAsmParser::OperandType operand;
118  Type type;
119  if (parser.parseOperand(operand) || parser.parseColonType(type) ||
120  parser.resolveOperand(operand, type, result.operands))
121  return failure();
122  return success();
123 }
124 
125 /// Parse optional operand and its type wrapped in parenthesis prefixed with
126 /// a keyword.
127 /// Example:
128 /// keyword `(` %vectorLength: i64 `)`
130  StringRef keyword,
131  OperationState &result) {
132  OpAsmParser::OperandType operand;
133  if (succeeded(parser.parseOptionalKeyword(keyword))) {
134  return failure(parser.parseLParen() ||
135  parseOperandAndType(parser, result) || parser.parseRParen());
136  }
137  return llvm::None;
138 }
139 
140 /// Parse optional operand and its type wrapped in parenthesis.
141 /// Example:
142 /// `(` %vectorLength: i64 `)`
144  OperationState &result) {
145  if (succeeded(parser.parseOptionalLParen())) {
146  return failure(parseOperandAndType(parser, result) || parser.parseRParen());
147  }
148  return llvm::None;
149 }
150 
151 /// Parse optional operand with its type prefixed with prefixKeyword `=`.
152 /// Example:
153 /// num=%gangNum: i32
155  OpAsmParser &parser, OperationState &result, StringRef prefixKeyword) {
156  if (succeeded(parser.parseOptionalKeyword(prefixKeyword))) {
157  parser.parseEqual();
158  return parseOperandAndType(parser, result);
159  }
160  return llvm::None;
161 }
162 
163 static bool isComputeOperation(Operation *op) {
164  return isa<acc::ParallelOp>(op) || isa<acc::LoopOp>(op);
165 }
166 
167 namespace {
168 /// Pattern to remove operation without region that have constant false `ifCond`
169 /// and remove the condition from the operation if the `ifCond` is a true
170 /// constant.
171 template <typename OpTy>
172 struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
174 
175  LogicalResult matchAndRewrite(OpTy op,
176  PatternRewriter &rewriter) const override {
177  // Early return if there is no condition.
178  if (!op.ifCond())
179  return success();
180 
181  auto constOp = op.ifCond().template getDefiningOp<arith::ConstantOp>();
182  if (constOp && constOp.getValue().template cast<IntegerAttr>().getInt())
183  rewriter.updateRootInPlace(op, [&]() { op.ifCondMutable().erase(0); });
184  else if (constOp)
185  rewriter.eraseOp(op);
186 
187  return success();
188  }
189 };
190 } // namespace
191 
192 //===----------------------------------------------------------------------===//
193 // ParallelOp
194 //===----------------------------------------------------------------------===//
195 
196 /// Parse acc.parallel operation
197 /// operation := `acc.parallel` `async` `(` index `)`?
198 /// `wait` `(` index-list `)`?
199 /// `num_gangs` `(` value `)`?
200 /// `num_workers` `(` value `)`?
201 /// `vector_length` `(` value `)`?
202 /// `if` `(` value `)`?
203 /// `self` `(` value `)`?
204 /// `reduction` `(` value-list `)`?
205 /// `copy` `(` value-list `)`?
206 /// `copyin` `(` value-list `)`?
207 /// `copyin_readonly` `(` value-list `)`?
208 /// `copyout` `(` value-list `)`?
209 /// `copyout_zero` `(` value-list `)`?
210 /// `create` `(` value-list `)`?
211 /// `create_zero` `(` value-list `)`?
212 /// `no_create` `(` value-list `)`?
213 /// `present` `(` value-list `)`?
214 /// `deviceptr` `(` value-list `)`?
215 /// `attach` `(` value-list `)`?
216 /// `private` `(` value-list `)`?
217 /// `firstprivate` `(` value-list `)`?
218 /// region attr-dict?
220  OperationState &result) {
  // Each clause below is optional; clauses are parsed strictly in the order
  // listed in the grammar comment above this function.
221  Builder &builder = parser.getBuilder();
223  firstprivateOperands, copyOperands, copyinOperands,
224  copyinReadonlyOperands, copyoutOperands, copyoutZeroOperands,
225  createOperands, createZeroOperands, noCreateOperands, presentOperands,
226  devicePtrOperands, attachOperands, waitOperands, reductionOperands;
227  SmallVector<Type, 8> waitOperandTypes, reductionOperandTypes,
228  copyOperandTypes, copyinOperandTypes, copyinReadonlyOperandTypes,
229  copyoutOperandTypes, copyoutZeroOperandTypes, createOperandTypes,
230  createZeroOperandTypes, noCreateOperandTypes, presentOperandTypes,
231  deviceptrOperandTypes, attachOperandTypes, privateOperandTypes,
232  firstprivateOperandTypes;
233 
  // NOTE(review): `operandTypes` is never referenced in this function; it
  // looks like dead code that can be removed.
234  SmallVector<Type, 8> operandTypes;
  // if/self each take one operand resolved as i1; hasIfCond/hasSelfCond
  // record whether the keyword was present.
235  OpAsmParser::OperandType ifCond, selfCond;
236  bool hasIfCond = false, hasSelfCond = false;
  // For the single-operand clauses, llvm::None means the keyword was absent
  // and a contained failure means the clause was present but malformed.
237  OptionalParseResult async, numGangs, numWorkers, vectorLength;
238  Type i1Type = builder.getI1Type();
239 
240  // async()?
241  async = parseOptionalOperandAndType(parser, ParallelOp::getAsyncKeyword(),
242  result);
243  if (async.hasValue() && failed(*async))
244  return failure();
245 
246  // wait()?
247  if (failed(parseOperandList(parser, ParallelOp::getWaitKeyword(),
248  waitOperands, waitOperandTypes, result)))
249  return failure();
250 
251  // num_gangs(value)?
252  numGangs = parseOptionalOperandAndType(
253  parser, ParallelOp::getNumGangsKeyword(), result);
254  if (numGangs.hasValue() && failed(*numGangs))
255  return failure();
256 
257  // num_workers(value)?
258  numWorkers = parseOptionalOperandAndType(
259  parser, ParallelOp::getNumWorkersKeyword(), result);
260  if (numWorkers.hasValue() && failed(*numWorkers))
261  return failure();
262 
263  // vector_length(value)?
264  vectorLength = parseOptionalOperandAndType(
265  parser, ParallelOp::getVectorLengthKeyword(), result);
266  if (vectorLength.hasValue() && failed(*vectorLength))
267  return failure();
268 
269  // if()?
270  if (failed(parseOptionalOperand(parser, ParallelOp::getIfKeyword(), ifCond,
271  i1Type, hasIfCond, result)))
272  return failure();
273 
274  // self()?
275  if (failed(parseOptionalOperand(parser, ParallelOp::getSelfKeyword(),
276  selfCond, i1Type, hasSelfCond, result)))
277  return failure();
278 
279  // reduction()?
280  if (failed(parseOperandList(parser, ParallelOp::getReductionKeyword(),
281  reductionOperands, reductionOperandTypes,
282  result)))
283  return failure();
284 
285  // copy()?
286  if (failed(parseOperandList(parser, ParallelOp::getCopyKeyword(),
287  copyOperands, copyOperandTypes, result)))
288  return failure();
289 
290  // copyin()?
291  if (failed(parseOperandList(parser, ParallelOp::getCopyinKeyword(),
292  copyinOperands, copyinOperandTypes, result)))
293  return failure();
294 
295  // copyin_readonly()?
296  if (failed(parseOperandList(parser, ParallelOp::getCopyinReadonlyKeyword(),
297  copyinReadonlyOperands,
298  copyinReadonlyOperandTypes, result)))
299  return failure();
300 
301  // copyout()?
302  if (failed(parseOperandList(parser, ParallelOp::getCopyoutKeyword(),
303  copyoutOperands, copyoutOperandTypes, result)))
304  return failure();
305 
306  // copyout_zero()?
307  if (failed(parseOperandList(parser, ParallelOp::getCopyoutZeroKeyword(),
308  copyoutZeroOperands, copyoutZeroOperandTypes,
309  result)))
310  return failure();
311 
312  // create()?
313  if (failed(parseOperandList(parser, ParallelOp::getCreateKeyword(),
314  createOperands, createOperandTypes, result)))
315  return failure();
316 
317  // create_zero()?
318  if (failed(parseOperandList(parser, ParallelOp::getCreateZeroKeyword(),
319  createZeroOperands, createZeroOperandTypes,
320  result)))
321  return failure();
322 
323  // no_create()?
324  if (failed(parseOperandList(parser, ParallelOp::getNoCreateKeyword(),
325  noCreateOperands, noCreateOperandTypes, result)))
326  return failure();
327 
328  // present()?
329  if (failed(parseOperandList(parser, ParallelOp::getPresentKeyword(),
330  presentOperands, presentOperandTypes, result)))
331  return failure();
332 
333  // deviceptr()?
334  if (failed(parseOperandList(parser, ParallelOp::getDevicePtrKeyword(),
335  devicePtrOperands, deviceptrOperandTypes,
336  result)))
337  return failure();
338 
339  // attach()?
340  if (failed(parseOperandList(parser, ParallelOp::getAttachKeyword(),
341  attachOperands, attachOperandTypes, result)))
342  return failure();
343 
344  // private()?
345  if (failed(parseOperandList(parser, ParallelOp::getPrivateKeyword(),
346  privateOperands, privateOperandTypes, result)))
347  return failure();
348 
349  // firstprivate()?
350  if (failed(parseOperandList(parser, ParallelOp::getFirstPrivateKeyword(),
351  firstprivateOperands, firstprivateOperandTypes,
352  result)))
353  return failure();
354 
355  // Parallel op region
356  if (failed(parseRegions<ParallelOp>(parser, result)))
357  return failure();
358 
  // Record how many operands each clause contributed, in parse order; a
  // single-operand clause contributes 1 when present and 0 otherwise.
  // NOTE(review): this order presumably has to match the operand order
  // declared for ParallelOp in ODS; keep both in sync when adding clauses.
359  result.addAttribute(
360  ParallelOp::getOperandSegmentSizeAttr(),
361  builder.getI32VectorAttr(
362  {static_cast<int32_t>(async.hasValue() ? 1 : 0),
363  static_cast<int32_t>(waitOperands.size()),
364  static_cast<int32_t>(numGangs.hasValue() ? 1 : 0),
365  static_cast<int32_t>(numWorkers.hasValue() ? 1 : 0),
366  static_cast<int32_t>(vectorLength.hasValue() ? 1 : 0),
367  static_cast<int32_t>(hasIfCond ? 1 : 0),
368  static_cast<int32_t>(hasSelfCond ? 1 : 0),
369  static_cast<int32_t>(reductionOperands.size()),
370  static_cast<int32_t>(copyOperands.size()),
371  static_cast<int32_t>(copyinOperands.size()),
372  static_cast<int32_t>(copyinReadonlyOperands.size()),
373  static_cast<int32_t>(copyoutOperands.size()),
374  static_cast<int32_t>(copyoutZeroOperands.size()),
375  static_cast<int32_t>(createOperands.size()),
376  static_cast<int32_t>(createZeroOperands.size()),
377  static_cast<int32_t>(noCreateOperands.size()),
378  static_cast<int32_t>(presentOperands.size()),
379  static_cast<int32_t>(devicePtrOperands.size()),
380  static_cast<int32_t>(attachOperands.size()),
381  static_cast<int32_t>(privateOperands.size()),
382  static_cast<int32_t>(firstprivateOperands.size())}));
383 
384  // Additional attributes
386  return failure();
387 
388  return success();
389 }
390 
391 static void print(OpAsmPrinter &printer, ParallelOp &op) {
392  // async()?
393  if (Value async = op.async())
394  printer << " " << ParallelOp::getAsyncKeyword() << "(" << async << ": "
395  << async.getType() << ")";
396 
397  // wait()?
398  printOperandList(op.waitOperands(), ParallelOp::getWaitKeyword(), printer);
399 
400  // num_gangs()?
401  if (Value numGangs = op.numGangs())
402  printer << " " << ParallelOp::getNumGangsKeyword() << "(" << numGangs
403  << ": " << numGangs.getType() << ")";
404 
405  // num_workers()?
406  if (Value numWorkers = op.numWorkers())
407  printer << " " << ParallelOp::getNumWorkersKeyword() << "(" << numWorkers
408  << ": " << numWorkers.getType() << ")";
409 
410  // vector_length()?
411  if (Value vectorLength = op.vectorLength())
412  printer << " " << ParallelOp::getVectorLengthKeyword() << "("
413  << vectorLength << ": " << vectorLength.getType() << ")";
414 
415  // if()?
416  if (Value ifCond = op.ifCond())
417  printer << " " << ParallelOp::getIfKeyword() << "(" << ifCond << ")";
418 
419  // self()?
420  if (Value selfCond = op.selfCond())
421  printer << " " << ParallelOp::getSelfKeyword() << "(" << selfCond << ")";
422 
423  // reduction()?
424  printOperandList(op.reductionOperands(), ParallelOp::getReductionKeyword(),
425  printer);
426 
427  // copy()?
428  printOperandList(op.copyOperands(), ParallelOp::getCopyKeyword(), printer);
429 
430  // copyin()?
431  printOperandList(op.copyinOperands(), ParallelOp::getCopyinKeyword(),
432  printer);
433 
434  // copyin_readonly()?
435  printOperandList(op.copyinReadonlyOperands(),
436  ParallelOp::getCopyinReadonlyKeyword(), printer);
437 
438  // copyout()?
439  printOperandList(op.copyoutOperands(), ParallelOp::getCopyoutKeyword(),
440  printer);
441 
442  // copyout_zero()?
443  printOperandList(op.copyoutZeroOperands(),
444  ParallelOp::getCopyoutZeroKeyword(), printer);
445 
446  // create()?
447  printOperandList(op.createOperands(), ParallelOp::getCreateKeyword(),
448  printer);
449 
450  // create_zero()?
451  printOperandList(op.createZeroOperands(), ParallelOp::getCreateZeroKeyword(),
452  printer);
453 
454  // no_create()?
455  printOperandList(op.noCreateOperands(), ParallelOp::getNoCreateKeyword(),
456  printer);
457 
458  // present()?
459  printOperandList(op.presentOperands(), ParallelOp::getPresentKeyword(),
460  printer);
461 
462  // deviceptr()?
463  printOperandList(op.devicePtrOperands(), ParallelOp::getDevicePtrKeyword(),
464  printer);
465 
466  // attach()?
467  printOperandList(op.attachOperands(), ParallelOp::getAttachKeyword(),
468  printer);
469 
470  // private()?
471  printOperandList(op.gangPrivateOperands(), ParallelOp::getPrivateKeyword(),
472  printer);
473 
474  // firstprivate()?
475  printOperandList(op.gangFirstPrivateOperands(),
476  ParallelOp::getFirstPrivateKeyword(), printer);
477 
478  printer << ' ';
479  printer.printRegion(op.region(),
480  /*printEntryBlockArgs=*/false,
481  /*printBlockTerminators=*/true);
482  printer.printOptionalAttrDictWithKeyword(
483  op->getAttrs(), ParallelOp::getOperandSegmentSizeAttr());
484 }
485 
486 unsigned ParallelOp::getNumDataOperands() {
487  return reductionOperands().size() + copyOperands().size() +
488  copyinOperands().size() + copyinReadonlyOperands().size() +
489  copyoutOperands().size() + copyoutZeroOperands().size() +
490  createOperands().size() + createZeroOperands().size() +
491  noCreateOperands().size() + presentOperands().size() +
492  devicePtrOperands().size() + attachOperands().size() +
493  gangPrivateOperands().size() + gangFirstPrivateOperands().size();
494 }
495 
496 Value ParallelOp::getDataOperand(unsigned i) {
497  unsigned numOptional = async() ? 1 : 0;
498  numOptional += numGangs() ? 1 : 0;
499  numOptional += numWorkers() ? 1 : 0;
500  numOptional += vectorLength() ? 1 : 0;
501  numOptional += ifCond() ? 1 : 0;
502  numOptional += selfCond() ? 1 : 0;
503  return getOperand(waitOperands().size() + numOptional + i);
504 }
505 
506 //===----------------------------------------------------------------------===//
507 // LoopOp
508 //===----------------------------------------------------------------------===//
509 
510 /// Parse acc.loop operation
511 /// operation := `acc.loop`
512 /// (`gang` ( `(` (`num=` value)? (`,` `static=` value `)`)? )? )?
513 /// (`vector` ( `(` value `)` )? )? (`worker` (`(` value `)`)? )?
514 /// (`vector_length` `(` value `)`)?
515 /// (`tile` `(` value-list `)`)?
516 /// (`private` `(` value-list `)`)?
517 /// (`reduction` `(` value-list `)`)?
518 /// region attr-dict?
520  Builder &builder = parser.getBuilder();
  // executionMapping accumulates one OpenACCExecMapping bit per gang / worker
  // / vector keyword seen.
521  unsigned executionMapping = OpenACCExecMapping::NONE;
  // One type vector shared by the tile, private and reduction lists below.
522  SmallVector<Type, 8> operandTypes;
523  SmallVector<OpAsmParser::OperandType, 8> privateOperands, reductionOperands;
525  OptionalParseResult gangNum, gangStatic, worker, vector;
526 
527  // gang?
528  if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangKeyword())))
529  executionMapping |= OpenACCExecMapping::GANG;
530 
  // NOTE(review): this `(` group is attempted even when the `gang` keyword
  // above was absent, while the printer only emits gang operands when the
  // GANG bit is set. Confirm whether it should be nested under that check.
531  // optional gang operand
532  if (succeeded(parser.parseOptionalLParen())) {
534  parser, result, LoopOp::getGangNumKeyword());
535  if (gangNum.hasValue() && failed(*gangNum))
536  return failure();
537  parser.parseOptionalComma();
539  parser, result, LoopOp::getGangStaticKeyword());
540  if (gangStatic.hasValue() && failed(*gangStatic))
541  return failure();
  // NOTE(review): both commas are optional, so `(num=... static=...)` without
  // a separator and a trailing `,` before `)` are both accepted.
542  parser.parseOptionalComma();
543  if (failed(parser.parseRParen()))
544  return failure();
545  }
546 
547  // worker?
548  if (succeeded(parser.parseOptionalKeyword(LoopOp::getWorkerKeyword())))
549  executionMapping |= OpenACCExecMapping::WORKER;
550 
  // NOTE(review): as with gang, this operand is parsed regardless of whether
  // the `worker` keyword was present.
551  // optional worker operand
552  worker = parseOptionalOperandAndType(parser, result);
553  if (worker.hasValue() && failed(*worker))
554  return failure();
555 
556  // vector?
557  if (succeeded(parser.parseOptionalKeyword(LoopOp::getVectorKeyword())))
558  executionMapping |= OpenACCExecMapping::VECTOR;
559 
560  // optional vector operand
561  vector = parseOptionalOperandAndType(parser, result);
562  if (vector.hasValue() && failed(*vector))
563  return failure();
564 
  // NOTE(review): tile, private and reduction all append to the same
  // `operandTypes`. Check that parseOperandList resolves only the types
  // appended by the current clause. If it resolves the whole vector, a loop
  // with more than one of these clauses fails with an operand/type count
  // mismatch. Per-clause type vectors would remove that dependency.
565  // tile()?
566  if (failed(parseOperandList(parser, LoopOp::getTileKeyword(), tileOperands,
567  operandTypes, result)))
568  return failure();
569 
570  // private()?
571  if (failed(parseOperandList(parser, LoopOp::getPrivateKeyword(),
572  privateOperands, operandTypes, result)))
573  return failure();
574 
575  // reduction()?
576  if (failed(parseOperandList(parser, LoopOp::getReductionKeyword(),
577  reductionOperands, operandTypes, result)))
578  return failure();
579 
  // exec_mapping is attached only when at least one of gang/worker/vector
  // was seen.
580  if (executionMapping != acc::OpenACCExecMapping::NONE)
581  result.addAttribute(LoopOp::getExecutionMappingAttrName(),
582  builder.getI64IntegerAttr(executionMapping));
583 
584  // Parse optional results in case there is a reduce.
585  if (parser.parseOptionalArrowTypeList(result.types))
586  return failure();
587 
588  if (failed(parseRegions<LoopOp>(parser, result)))
589  return failure();
590 
  // Operand counts per group, in parse order: gangNum, gangStatic, worker,
  // vector, tile, private, reduction.
591  result.addAttribute(LoopOp::getOperandSegmentSizeAttr(),
592  builder.getI32VectorAttr(
593  {static_cast<int32_t>(gangNum.hasValue() ? 1 : 0),
594  static_cast<int32_t>(gangStatic.hasValue() ? 1 : 0),
595  static_cast<int32_t>(worker.hasValue() ? 1 : 0),
596  static_cast<int32_t>(vector.hasValue() ? 1 : 0),
597  static_cast<int32_t>(tileOperands.size()),
598  static_cast<int32_t>(privateOperands.size()),
599  static_cast<int32_t>(reductionOperands.size())}));
600 
602  return failure();
603 
604  return success();
605 }
606 
/// Print acc.loop. The gang, worker and vector keywords, together with their
/// optional operands, are printed only when the matching bit of exec_mapping
/// is set. The tile, private and reduction lists, the optional result types
/// and the region follow.
607 static void print(OpAsmPrinter &printer, LoopOp &op) {
608  unsigned execMapping = op.exec_mapping();
609  if (execMapping & OpenACCExecMapping::GANG) {
610  printer << " " << LoopOp::getGangKeyword();
611  Value gangNum = op.gangNum();
612  Value gangStatic = op.gangStatic();
613 
614  // Print optional gang operands
  // Form: `(num=%n: type, static=%s: type)`, either part may be absent.
615  if (gangNum || gangStatic) {
616  printer << "(";
617  if (gangNum) {
618  printer << LoopOp::getGangNumKeyword() << "=" << gangNum << ": "
619  << gangNum.getType();
620  if (gangStatic)
621  printer << ", ";
622  }
623  if (gangStatic)
624  printer << LoopOp::getGangStaticKeyword() << "=" << gangStatic << ": "
625  << gangStatic.getType();
626  printer << ")";
627  }
628  }
629 
  // NOTE(review): gang/worker/vector operands are printed only under their
  // exec_mapping bit. parseLoopOp accepts the parenthesized operands without
  // the keyword, so such an op would not round-trip.
630  if (execMapping & OpenACCExecMapping::WORKER) {
631  printer << " " << LoopOp::getWorkerKeyword();
632 
633  // Print optional worker operand if present
634  if (Value workerNum = op.workerNum())
635  printer << "(" << workerNum << ": " << workerNum.getType() << ")";
636  }
637 
638  if (execMapping & OpenACCExecMapping::VECTOR) {
639  printer << " " << LoopOp::getVectorKeyword();
640 
641  // Print optional vector operand if present
642  if (Value vectorLength = op.vectorLength())
643  printer << "(" << vectorLength << ": " << vectorLength.getType() << ")";
644  }
645 
646  // tile()?
647  printOperandList(op.tileOperands(), LoopOp::getTileKeyword(), printer);
648 
649  // private()?
650  printOperandList(op.privateOperands(), LoopOp::getPrivateKeyword(), printer);
651 
652  // reduction()?
653  printOperandList(op.reductionOperands(), LoopOp::getReductionKeyword(),
654  printer);
655 
656  if (op.getNumResults() > 0)
657  printer << " -> (" << op.getResultTypes() << ")";
658 
659  printer << ' ';
660  printer.printRegion(op.region(),
661  /*printEntryBlockArgs=*/false,
662  /*printBlockTerminators=*/true);
  // Remaining attributes, excluding exec_mapping and the operand segment
  // sizes, both of which parseLoopOp re-derives.
663 
665  op->getAttrs(), {LoopOp::getExecutionMappingAttrName(),
666  LoopOp::getOperandSegmentSizeAttr()});
667 }
668 
669 static LogicalResult verifyLoopOp(acc::LoopOp loopOp) {
670  // auto, independent and seq attribute are mutually exclusive.
671  if ((loopOp.auto_() && (loopOp.independent() || loopOp.seq())) ||
672  (loopOp.independent() && loopOp.seq())) {
673  loopOp.emitError("only one of " + acc::LoopOp::getAutoAttrName() + ", " +
674  acc::LoopOp::getIndependentAttrName() + ", " +
675  acc::LoopOp::getSeqAttrName() +
676  " can be present at the same time");
677  return failure();
678  }
679 
680  // Gang, worker and vector are incompatible with seq.
681  if (loopOp.seq() && loopOp.exec_mapping() != OpenACCExecMapping::NONE) {
682  loopOp.emitError("gang, worker or vector cannot appear with the seq attr");
683  return failure();
684  }
685 
686  // Check non-empty body().
687  if (loopOp.region().empty()) {
688  loopOp.emitError("expected non-empty body.");
689  return failure();
690  }
691 
692  return success();
693 }
694 
695 //===----------------------------------------------------------------------===//
696 // DataOp
697 //===----------------------------------------------------------------------===//
698 
699 static LogicalResult verify(acc::DataOp dataOp) {
700  // 2.6.5. Data Construct restriction
701  // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
702  // attach, or default clause must appear on a data construct.
703  if (dataOp.getOperands().empty() && !dataOp.defaultAttr())
704  return dataOp.emitError("at least one operand or the default attribute "
705  "must appear on the data operation");
706  return success();
707 }
708 
709 unsigned DataOp::getNumDataOperands() {
710  return copyOperands().size() + copyinOperands().size() +
711  copyinReadonlyOperands().size() + copyoutOperands().size() +
712  copyoutZeroOperands().size() + createOperands().size() +
713  createZeroOperands().size() + noCreateOperands().size() +
714  presentOperands().size() + deviceptrOperands().size() +
715  attachOperands().size();
716 }
717 
718 Value DataOp::getDataOperand(unsigned i) {
719  unsigned numOptional = ifCond() ? 1 : 0;
720  return getOperand(numOptional + i);
721 }
722 
723 //===----------------------------------------------------------------------===//
724 // ExitDataOp
725 //===----------------------------------------------------------------------===//
726 
727 static LogicalResult verify(acc::ExitDataOp op) {
728  // 2.6.6. Data Exit Directive restriction
729  // At least one copyout, delete, or detach clause must appear on an exit data
730  // directive.
731  if (op.copyoutOperands().empty() && op.deleteOperands().empty() &&
732  op.detachOperands().empty())
733  return op.emitError(
734  "at least one operand in copyout, delete or detach must appear on the "
735  "exit data operation");
736 
737  // The async attribute represent the async clause without value. Therefore the
738  // attribute and operand cannot appear at the same time.
739  if (op.asyncOperand() && op.async())
740  return op.emitError("async attribute cannot appear with asyncOperand");
741 
742  // The wait attribute represent the wait clause without values. Therefore the
743  // attribute and operands cannot appear at the same time.
744  if (!op.waitOperands().empty() && op.wait())
745  return op.emitError("wait attribute cannot appear with waitOperands");
746 
747  if (op.waitDevnum() && op.waitOperands().empty())
748  return op.emitError("wait_devnum cannot appear without waitOperands");
749 
750  return success();
751 }
752 
753 unsigned ExitDataOp::getNumDataOperands() {
754  return copyoutOperands().size() + deleteOperands().size() +
755  detachOperands().size();
756 }
757 
758 Value ExitDataOp::getDataOperand(unsigned i) {
759  unsigned numOptional = ifCond() ? 1 : 0;
760  numOptional += asyncOperand() ? 1 : 0;
761  numOptional += waitDevnum() ? 1 : 0;
762  return getOperand(waitOperands().size() + numOptional + i);
763 }
764 
/// Register ExitDataOp canonicalizations: RemoveConstantIfCondition drops a
/// constant-true `ifCond` and erases the op when `ifCond` is constant false.
765 void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
766  MLIRContext *context) {
767  results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
768 }
769 
770 //===----------------------------------------------------------------------===//
771 // EnterDataOp
772 //===----------------------------------------------------------------------===//
773 
774 static LogicalResult verify(acc::EnterDataOp op) {
775  // 2.6.6. Data Enter Directive restriction
776  // At least one copyin, create, or attach clause must appear on an enter data
777  // directive.
778  if (op.copyinOperands().empty() && op.createOperands().empty() &&
779  op.createZeroOperands().empty() && op.attachOperands().empty())
780  return op.emitError(
781  "at least one operand in copyin, create, "
782  "create_zero or attach must appear on the enter data operation");
783 
784  // The async attribute represent the async clause without value. Therefore the
785  // attribute and operand cannot appear at the same time.
786  if (op.asyncOperand() && op.async())
787  return op.emitError("async attribute cannot appear with asyncOperand");
788 
789  // The wait attribute represent the wait clause without values. Therefore the
790  // attribute and operands cannot appear at the same time.
791  if (!op.waitOperands().empty() && op.wait())
792  return op.emitError("wait attribute cannot appear with waitOperands");
793 
794  if (op.waitDevnum() && op.waitOperands().empty())
795  return op.emitError("wait_devnum cannot appear without waitOperands");
796 
797  return success();
798 }
799 
800 unsigned EnterDataOp::getNumDataOperands() {
801  return copyinOperands().size() + createOperands().size() +
802  createZeroOperands().size() + attachOperands().size();
803 }
804 
805 Value EnterDataOp::getDataOperand(unsigned i) {
806  unsigned numOptional = ifCond() ? 1 : 0;
807  numOptional += asyncOperand() ? 1 : 0;
808  numOptional += waitDevnum() ? 1 : 0;
809  return getOperand(waitOperands().size() + numOptional + i);
810 }
811 
/// Register EnterDataOp canonicalizations: RemoveConstantIfCondition drops a
/// constant-true `ifCond` and erases the op when `ifCond` is constant false.
812 void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
813  MLIRContext *context) {
814  results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
815 }
816 
817 //===----------------------------------------------------------------------===//
818 // InitOp
819 //===----------------------------------------------------------------------===//
820 
821 static LogicalResult verify(acc::InitOp initOp) {
822  Operation *currOp = initOp;
823  while ((currOp = currOp->getParentOp())) {
824  if (isComputeOperation(currOp))
825  return initOp.emitOpError("cannot be nested in a compute operation");
826  }
827  return success();
828 }
829 
830 //===----------------------------------------------------------------------===//
831 // ShutdownOp
832 //===----------------------------------------------------------------------===//
833 
834 static LogicalResult verify(acc::ShutdownOp op) {
835  Operation *currOp = op;
836  while ((currOp = currOp->getParentOp())) {
837  if (isComputeOperation(currOp))
838  return op.emitOpError("cannot be nested in a compute operation");
839  }
840  return success();
841 }
842 
843 //===----------------------------------------------------------------------===//
844 // UpdateOp
845 //===----------------------------------------------------------------------===//
846 
847 static LogicalResult verify(acc::UpdateOp updateOp) {
848  // At least one of host or device should have a value.
849  if (updateOp.hostOperands().empty() && updateOp.deviceOperands().empty())
850  return updateOp.emitError("at least one value must be present in"
851  " hostOperands or deviceOperands");
852 
853  // The async attribute represent the async clause without value. Therefore the
854  // attribute and operand cannot appear at the same time.
855  if (updateOp.asyncOperand() && updateOp.async())
856  return updateOp.emitError("async attribute cannot appear with "
857  " asyncOperand");
858 
859  // The wait attribute represent the wait clause without values. Therefore the
860  // attribute and operands cannot appear at the same time.
861  if (!updateOp.waitOperands().empty() && updateOp.wait())
862  return updateOp.emitError("wait attribute cannot appear with waitOperands");
863 
864  if (updateOp.waitDevnum() && updateOp.waitOperands().empty())
865  return updateOp.emitError("wait_devnum cannot appear without waitOperands");
866 
867  return success();
868 }
869 
870 unsigned UpdateOp::getNumDataOperands() {
871  return hostOperands().size() + deviceOperands().size();
872 }
873 
874 Value UpdateOp::getDataOperand(unsigned i) {
875  unsigned numOptional = asyncOperand() ? 1 : 0;
876  numOptional += waitDevnum() ? 1 : 0;
877  numOptional += ifCond() ? 1 : 0;
878  return getOperand(waitOperands().size() + deviceTypeOperands().size() +
879  numOptional + i);
880 }
881 
/// Register UpdateOp canonicalizations: RemoveConstantIfCondition drops a
/// constant-true `ifCond` and erases the op when `ifCond` is constant false.
882 void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
883  MLIRContext *context) {
884  results.add<RemoveConstantIfCondition<UpdateOp>>(context);
885 }
886 
887 //===----------------------------------------------------------------------===//
888 // WaitOp
889 //===----------------------------------------------------------------------===//
890 
891 static LogicalResult verify(acc::WaitOp waitOp) {
892  // The async attribute represent the async clause without value. Therefore the
893  // attribute and operand cannot appear at the same time.
894  if (waitOp.asyncOperand() && waitOp.async())
895  return waitOp.emitError("async attribute cannot appear with asyncOperand");
896 
897  if (waitOp.waitDevnum() && waitOp.waitOperands().empty())
898  return waitOp.emitError("wait_devnum cannot appear without waitOperands");
899 
900  return success();
901 }
902 
903 #define GET_OP_CLASSES
904 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
905 
906 #define GET_ATTRDEF_CLASSES
907 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
virtual ParseResult parseOperand(OperandType &result)=0
Parse a single operand.
This is the representation of an operand reference.
Include the generated interface declarations.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
ParseResult resolveOperands(ArrayRef< OperandType > operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseLParen()=0
Parse a ( token.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:881
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
LogicalResult verify(Operation *op)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:353
virtual ParseResult parseRegionArgument(OperandType &argument)=0
Parse a region argument; this argument is resolved when calling 'parseRegion'.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition: OpenACC.cpp:41
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
static void printOperandList(Operation::operand_range operands, StringRef listName, OpAsmPrinter &printer)
Definition: OpenACC.cpp:88
static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result)
Parse acc.loop operation: operation := acc.loop (gang ( ( (num= value)? (, static= value ))...
Definition: OpenACC.cpp:519
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
The OpAsmParser has methods for interacting with the asm parser: parsing things from it...
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
static LogicalResult verifyLoopOp(acc::LoopOp loopOp)
Definition: OpenACC.cpp:669
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
SmallVector< Value, 4 > operands
static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser, StringRef keyword, OperationState &result)
Parse optional operand and its type wrapped in parenthesis prefixed with a keyword.
Definition: OpenACC.cpp:129
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
virtual ParseResult resolveOperand(const OperandType &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
static ParseResult parseOperandList(OpAsmParser &parser, StringRef keyword, SmallVectorImpl< OpAsmParser::OperandType > &args, SmallVectorImpl< Type > &argTypes, OperationState &result)
Definition: OpenACC.cpp:57
virtual ParseResult parseRegion(Region &region, ArrayRef< OperandType > arguments={}, ArrayRef< Type > argTypes={}, ArrayRef< Location > argLocations={}, bool enableNameShadowing=false)=0
Parses a region.
virtual llvm::SMLoc getCurrentLocation()=0
Get the location of the next token.
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:99
static bool isComputeOperation(Operation *op)
Definition: OpenACC.cpp:163
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:117
virtual ParseResult parseRParen()=0
Parse a ) token.
IntegerType getI1Type()
Definition: Builders.cpp:50
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:789
This represents an operation in an abstracted form, suitable for use with the builder APIs...
static ParseResult parseOptionalOperand(OpAsmParser &parser, StringRef keyword, OpAsmParser::OperandType &operand, Type type, bool &hasOptional, OperationState &result)
Definition: OpenACC.cpp:100
static OptionalParseResult parserOptionalOperandAndTypeWithPrefix(OpAsmParser &parser, OperationState &result, StringRef prefixKeyword)
Parse optional operand with its type prefixed with prefixKeyword =.
Definition: OpenACC.cpp:154
static void print(OpAsmPrinter &printer, ParallelOp &op)
Definition: OpenACC.cpp:391
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
static ParseResult parseOperandAndType(OpAsmParser &parser, OperationState &result)
Definition: OpenACC.cpp:115
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:52
static ParseResult parseParallelOp(OpAsmParser &parser, OperationState &result)
Parse acc.parallel operation: operation := acc.parallel async ( index )? wait ( index-list )...
Definition: OpenACC.cpp:219
NamedAttrList attributes
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
Region * addRegion()
Create a region that should be attached to the operation.
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:49
Type getType() const
Return the type of this value.
Definition: Value.h:117
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
Definition: PatternMatch.h:930
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
DenseIntElementsAttr getI32VectorAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:109
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with the 'attributes' keyword.
This class implements the operand iterators for the Operation class.
bool hasValue() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:62
virtual ParseResult parseEqual()=0
Parse a = token.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers...
Definition: Operation.cpp:518
This class represents success/failure for operation parsing.
Definition: OpDefinition.h:36
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
SmallVector< Type, 4 > types
Types of the results of this operation.