MLIR 23.0.0git
OpenACCUtilsCG.cpp
Go to the documentation of this file.
1//===- OpenACCUtilsCG.cpp - OpenACC Code Generation Utilities -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utility functions for OpenACC code generation.
10//
11//===----------------------------------------------------------------------===//
12
16#include "mlir/IR/BuiltinOps.h"
17#include "mlir/IR/IRMapping.h"
18#include "llvm/ADT/STLExtras.h"
19
20namespace mlir {
21namespace acc {
22
23std::optional<DataLayout> getDataLayout(Operation *op, bool allowDefault) {
24 if (!op)
25 return std::nullopt;
26
27 // Walk up the parent chain to find the nearest operation with an explicit
28 // data layout spec. Check ModuleOp explicitly since it does not actually
29 // implement DataLayoutOpInterface as a trait (it just has the same methods).
30 Operation *current = op;
31 while (current) {
32 // Check for ModuleOp with explicit data layout spec
33 if (auto mod = llvm::dyn_cast<ModuleOp>(current)) {
34 if (mod.getDataLayoutSpec())
35 return DataLayout(mod);
36 } else if (auto dataLayoutOp =
37 llvm::dyn_cast<DataLayoutOpInterface>(current)) {
38 // Check other DataLayoutOpInterface implementations
39 if (dataLayoutOp.getDataLayoutSpec())
40 return DataLayout(dataLayoutOp);
41 }
42 current = current->getParentOp();
43 }
44
45 // No explicit data layout found; return default if allowed
46 if (allowDefault) {
47 // Check if op itself is a ModuleOp
48 if (auto mod = llvm::dyn_cast<ModuleOp>(op))
49 return DataLayout(mod);
50 // Otherwise check parents
51 if (auto mod = op->getParentOfType<ModuleOp>())
52 return DataLayout(mod);
53 }
54
55 return std::nullopt;
56}
57
58ComputeRegionOp buildComputeRegion(Location loc, ValueRange launchArgs,
59 ValueRange inputArgs, llvm::StringRef origin,
60 Region &regionToClone,
61 RewriterBase &rewriter, IRMapping &mapping,
62 ValueRange output,
63 FlatSymbolRefAttr kernelFuncName,
64 FlatSymbolRefAttr kernelModuleName,
65 Value stream, ValueRange inputArgsToMap) {
66 SmallVector<Type> resultTypes;
67 for (auto val : output)
68 resultTypes.push_back(val.getType());
69 auto computeRegion =
70 ComputeRegionOp::create(rewriter, loc, resultTypes, launchArgs, inputArgs,
71 stream, origin, kernelFuncName, kernelModuleName);
72
73 assert(!regionToClone.getBlocks().empty() &&
74 "empty region for acc.compute_region");
75 OpBuilder::InsertionGuard guard(rewriter);
76
77 ValueRange mapKeys = inputArgsToMap.empty() ? inputArgs : inputArgsToMap;
78 assert(mapKeys.size() == inputArgs.size() &&
79 "inputArgsToMap must have same size as inputArgs when provided");
80
81 Type indexType = rewriter.getIndexType();
82 Block *entryBlock = rewriter.createBlock(&computeRegion.getRegion());
83 for (size_t i = 0; i < launchArgs.size(); ++i)
84 entryBlock->addArgument(indexType, loc);
85 for (Value input : inputArgs)
86 entryBlock->addArgument(input.getType(), loc);
87 for (size_t i = 0; i < inputArgs.size(); ++i)
88 mapping.map(mapKeys[i], entryBlock->getArgument(launchArgs.size() + i));
89 rewriter.setInsertionPointToStart(entryBlock);
90 if (regionToClone.getBlocks().size() == 1) {
91 for (auto &op : regionToClone.front().getOperations()) {
92 if (op.hasTrait<OpTrait::IsTerminator>())
93 break;
94 rewriter.clone(op, mapping);
95 }
96 SmallVector<Value> yieldOperands;
97 for (auto val : output)
98 yieldOperands.push_back(mapping.lookup(val));
99 rewriter.setInsertionPointToEnd(entryBlock);
100 YieldOp::create(rewriter, loc, yieldOperands);
101 } else {
103 regionToClone, mapping, loc, rewriter);
104 if (!exeRegion) {
105 rewriter.eraseOp(computeRegion);
106 return nullptr;
107 }
109 llvm::to_vector(exeRegion.getOps<scf::YieldOp>()));
110 assert(!yieldOps.empty() &&
111 "multi-block region must contain at least one scf.yield");
112 assert(llvm::all_of(yieldOps,
113 [&output](scf::YieldOp yieldOp) {
114 return yieldOp.getNumOperands() ==
115 static_cast<int64_t>(output.size()) &&
116 llvm::all_of(
117 llvm::zip(yieldOp.getOperands(), output),
118 [](auto pair) {
119 return std::get<0>(pair).getType() ==
120 std::get<1>(pair).getType();
121 });
122 }) &&
123 "each scf.yield operand count and types must match output");
124 rewriter.setInsertionPointToEnd(entryBlock);
125 YieldOp::create(rewriter, loc, exeRegion.getResults());
126 }
127
128 return computeRegion;
129}
130
131} // namespace acc
132} // namespace mlir
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
OpListType & getOperations()
Definition Block.h:147
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
IndexType getIndexType()
Definition Builders.cpp:55
The main mechanism for performing data layout queries.
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
Definition Builders.cpp:566
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:438
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition Operation.h:252
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:256
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockListType & getBlocks()
Definition Region.h:45
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition Value.h:96
std::optional< DataLayout > getDataLayout(Operation *op, bool allowDefault=true)
Get the data layout for an operation.
ComputeRegionOp buildComputeRegion(Location loc, ValueRange launchArgs, ValueRange inputArgs, llvm::StringRef origin, Region &regionToClone, RewriterBase &rewriter, IRMapping &mapping, ValueRange output={}, FlatSymbolRefAttr kernelFuncName={}, FlatSymbolRefAttr kernelModuleName={}, Value stream={}, ValueRange inputArgsToMap={})
Build an acc.compute_region operation by cloning a source region.
scf::ExecuteRegionOp wrapMultiBlockRegionWithSCFExecuteRegion(Region &region, IRMapping &mapping, Location loc, RewriterBase &rewriter)
Wrap a multi-block region in an scf.execute_region.
Include the generated interface declarations.