58#include "llvm/ADT/STLExtras.h"
59#include "llvm/ADT/TypeSwitch.h"
60#include "llvm/Support/Debug.h"
61#include "llvm/Support/ErrorHandling.h"
65#define GEN_PASS_DEF_ACCRECIPEMATERIALIZATION
66#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
70#define DEBUG_TYPE "acc-recipe-materialization"
76static void saveVarName(StringRef name,
Value dst) {
82 if (isa<ACC_DATA_ENTRY_OPS>(dstOp))
85 acc::VarNameAttr::get(dstOp->getContext(), name));
88 auto blockArg = dyn_cast<BlockArgument>(dst);
91 Block *block = blockArg.getOwner();
98 auto funcOp = dyn_cast<FunctionOpInterface>(parent);
101 unsigned argIdx = blockArg.getArgNumber();
102 if (argIdx >= funcOp.getNumArguments())
107 acc::VarNameAttr::get(parent->
getContext(), name));
110static void saveVarName(
Value src,
Value dst) {
117 for (
auto it = block->
begin(); it != std::next(ip); ++it) {
119 if (attr && attr.getName() == placeholder) {
124 acc::VarNameAttr::get(it->getContext(), name));
132template <
typename RecipeOpTy>
133static void cloneDestroy(RecipeOpTy recipe,
mlir::Block *block,
137 Region &destroyRegion = recipe.getDestroyRegion();
138 assert(destroyRegion.
getBlocks().front().getNumArguments() ==
140 "unexpected acc recipe destroy block arguments");
141 mapping.
map(destroyRegion.
getBlocks().front().getArguments(), arguments);
146class ACCRecipeMaterialization
150 ACCRecipeMaterialization>::ACCRecipeMaterializationBase;
151 void runOnOperation()
override;
167 void handleFirstprivateMapping(acc::FirstprivateOp firstprivateOp)
const;
168 template <
typename OpTy>
169 void removeRecipe(OpTy op, ModuleOp moduleOp)
const;
170 template <
typename OpTy,
typename RecipeOpTy,
typename AccOpTy>
171 LogicalResult materialize(OpTy op, RecipeOpTy recipe, AccOpTy accOp,
173 template <
typename OpTy>
174 LogicalResult materializeForACCOp(OpTy accOp,
178void ACCRecipeMaterialization::handleFirstprivateMapping(
179 acc::FirstprivateOp firstprivateOp)
const {
181 auto mapFirstprivateOp = acc::FirstprivateMapInitialOp::create(
182 builder, firstprivateOp.getLoc(), firstprivateOp.getVar(),
183 firstprivateOp.getStructured(), firstprivateOp.getImplicit(),
184 firstprivateOp.getBounds());
185 mapFirstprivateOp.setName(firstprivateOp.getName());
186 firstprivateOp.getVarMutable().assign(mapFirstprivateOp.getAccVar());
189template <
typename OpTy>
190void ACCRecipeMaterialization::removeRecipe(OpTy op, ModuleOp moduleOp)
const {
191 auto recipeName = op.getNameAttr();
193 LLVM_DEBUG(llvm::dbgs() <<
"erasing recipe: " << recipeName <<
"\n");
197 std::optional<SymbolTable::UseRange> symbolUses =
198 op.getSymbolUses(moduleOp);
199 if (symbolUses.has_value()) {
201 llvm::dbgs() <<
"symbol use: ";
202 symbolUse.getUser()->dump();
206 llvm_unreachable(
"expected no use of recipe symbol");
210template <
typename OpTy,
typename RecipeOpTy,
typename AccOpTy>
212ACCRecipeMaterialization::materialize(OpTy op, RecipeOpTy recipe, AccOpTy accOp,
214 Region ®ion = accOp.getRegion();
215 Value origPtr = op.getVar();
216 Value accPtr = op.getAccVar();
217 assert(accPtr &&
"invalid op: null acc var");
223 Region &initRegion = recipe.getInitRegion();
224 unsigned initNumArguments =
225 initRegion.
getBlocks().front().getArguments().size();
226 if (initNumArguments > 1) {
229 if ((initNumArguments - 1) % 3 != 0) {
231 "privatization of array section with extents");
242 initRegion.
getBlocks().front().getArgument(argIdx++).getType(),
246 auto dataBound = bound.getDefiningOp<acc::DataBoundsOp>();
248 "acc.reduction's bound must be defined by acc.bounds");
253 castValueToArgType(dataBound.getLoc(), dataBound.getLowerbound());
255 castValueToArgType(dataBound.getLoc(), dataBound.getUpperbound());
257 castValueToArgType(dataBound.getLoc(), dataBound.getStride());
258 triples.append({lb,
ub, step});
260 assert(triples.size() + 1 == initNumArguments &&
261 "mismatch between number bounds and number of recipe init block "
267 initArgs.append(triples);
268 mapping.
map(initRegion.
getBlocks().front().getArguments(), initArgs);
270 if constexpr (std::is_same_v<OpTy, acc::PrivateOp>) {
274 &initRegion, block, block->
begin(), mapping, {accPtr});
275 assert(results.size() == 1 &&
"expected single result from init region");
276 saveVarName(op.getAccVar(), results[0]);
279 if (!recipe.getDestroyRegion().empty()) {
280 results.insert(results.begin(), origPtr);
281 results.append(triples);
282 cloneDestroy(recipe, block, std::prev(block->
end()), results);
284 }
else if constexpr (std::is_same_v<OpTy, acc::FirstprivateOp>) {
288 &initRegion, block, block->
begin(), mapping, {accPtr});
289 assert(results.size() == 1 &&
"expected single result from init region");
290 saveVarName(op.getAccVar(), results[0]);
293 results.insert(results.begin(), origPtr);
294 results.append(triples);
298 mapping.
map(recipe.getCopyRegion().front().getArguments(), results);
302 if (!recipe.getDestroyRegion().empty()) {
304 cloneDestroy(recipe, block, std::prev(block->
end()), results);
306 }
else if constexpr (std::is_same_v<OpTy, acc::ReductionOp>) {
307 auto cloneRegionIntoAccRegion = [&](
Region *src,
Region *dest,
312 b.setInsertionPoint(terminator);
314 acc::YieldOp::create(
b, op.getLoc(), terminator->
getOperands());
316 acc::YieldOp::create(
b, op.getLoc(),
ValueRange{});
321 if constexpr (std::is_same_v<AccOpTy, acc::ParallelOp>)
322 b.setInsertionPointToStart(®ion.
front());
323 else if constexpr (std::is_same_v<AccOpTy, acc::LoopOp>)
324 b.setInsertionPoint(op);
326 llvm_unreachable(
"unexpected acc op with reduction recipe");
328 auto reductionOp = acc::ReductionInitOp::create(
329 b, op.getLoc(), origPtr, recipe.getReductionOperatorAttr());
330 saveVarName(op.getAccVar(), reductionOp.getResult());
331 cloneRegionIntoAccRegion(&initRegion, &reductionOp.getRegion(),
333 Block *initBlock = &reductionOp.getRegion().
front();
334 resolveVarNamePlaceholders(initBlock, std::prev(initBlock->
end()),
341 Region &combinerRegion = recipe.getCombinerRegion();
344 if constexpr (std::is_same_v<AccOpTy, acc::ParallelOp>)
346 else if constexpr (std::is_same_v<AccOpTy, acc::LoopOp>)
347 b.setInsertionPointAfter(accOp);
349 llvm_unreachable(
"unexpected acc op with reduction recipe");
357 argsRemapping.append(triples);
360 auto combineRegionOp = acc::ReductionCombineRegionOp::create(
361 b, op.getLoc(), origPtr, reductionOp.getResult());
362 cloneRegionIntoAccRegion(&combinerRegion, &combineRegionOp.getRegion(),
365 auto setSeqParDimsForRecipeLoops = [](
Region *r) {
366 r->walk([](LoopLikeOpInterface loopLike) {
368 acc::GPUParallelDimsAttr::name,
369 acc::GPUParallelDimsAttr::seq(loopLike->getContext()));
372 setSeqParDimsForRecipeLoops(&reductionOp.getRegion());
373 setSeqParDimsForRecipeLoops(&combineRegionOp.getRegion());
375 if (!recipe.getDestroyRegion().empty()) {
378 cloneDestroy(recipe, combineRegionOp->getBlock(), ip, results);
381 llvm_unreachable(
"unexpected op type");
388template <
typename OpTy>
389LogicalResult ACCRecipeMaterialization::materializeForACCOp(
391 assert(isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(accOp));
393 if (!accOp.getFirstprivateOperands().empty()) {
397 accOp.getFirstprivateOperandsMutable().clear();
398 for (
Value operand : operands) {
399 auto firstprivateOp = cast<acc::FirstprivateOp>(operand.getDefiningOp());
400 auto symbolRef = cast<SymbolRefAttr>(firstprivateOp.getRecipeAttr());
402 auto recipeOp = cast<acc::FirstprivateRecipeOp>(decl);
403 LLVM_DEBUG(llvm::dbgs() <<
"materializing: " << firstprivateOp <<
"\n"
404 << symbolRef <<
"\n");
405 handleFirstprivateMapping(firstprivateOp);
406 if (failed(materialize(firstprivateOp, recipeOp, accOp, accSupport)))
411 if (!accOp.getPrivateOperands().empty()) {
415 accOp.getPrivateOperandsMutable().clear();
416 for (
Value operand : operands) {
417 auto privateOp = cast<acc::PrivateOp>(operand.getDefiningOp());
418 auto symbolRef = cast<SymbolRefAttr>(privateOp.getRecipeAttr());
420 auto recipeOp = cast<acc::PrivateRecipeOp>(decl);
421 LLVM_DEBUG(llvm::dbgs() <<
"materializing: " << privateOp <<
"\n"
422 << symbolRef <<
"\n");
423 if (failed(materialize(privateOp, recipeOp, accOp, accSupport)))
428 if (!accOp.getReductionOperands().empty()) {
432 accOp.getReductionOperandsMutable().clear();
433 for (
Value operand : operands) {
434 auto reductionOp = cast<acc::ReductionOp>(operand.getDefiningOp());
435 auto symbolRef = cast<SymbolRefAttr>(reductionOp.getRecipeAttr());
437 auto recipeOp = cast<acc::ReductionRecipeOp>(decl);
438 LLVM_DEBUG(llvm::dbgs() <<
"materializing: " << reductionOp <<
"\n"
439 << symbolRef <<
"\n");
440 if (failed(materialize(reductionOp, recipeOp, accOp, accSupport)))
447void ACCRecipeMaterialization::runOnOperation() {
448 ModuleOp moduleOp = getOperation();
452 bool anyFailed =
false;
457 [&](
auto constructOp) {
458 if (failed(materializeForACCOp(constructOp, accSupport)))
469 if (
auto recipe = dyn_cast<acc::ReductionRecipeOp>(op))
470 removeRecipe(recipe, moduleOp);
471 else if (
auto recipe = dyn_cast<acc::PrivateRecipeOp>(op))
472 removeRecipe(recipe, moduleOp);
473 else if (
auto recipe = dyn_cast<acc::FirstprivateRecipeOp>(op))
474 removeRecipe(recipe, moduleOp);
Block represents an ordered list of Operations.
OpListType::iterator iterator
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
bool isEntryBlock()
Return if this block is the entry block in the parent region.
This is a utility class for mapping one set of IR entities to another.
void clear()
Clears all mappings held by the mapper.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
operand_range getOperands()
Returns an iterator on the underlying Value's.
MLIRContext * getContext()
Return the context this operation is associated with.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void cloneInto(Region *dest, IRMapping &mapper)
Clone the internal blocks from this region into dest.
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
This class represents a specific symbol use.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static bool symbolKnownUseEmpty(StringAttr symbol, Operation *from)
Return if the given symbol is known to have no uses that are nested within the given operation 'from'...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
InFlightDiagnostic emitNYI(Location loc, const Twine &message)
Report a case that is not yet supported by the implementation.
#define ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS
std::string getVariableName(mlir::Value v)
Attempts to extract the variable name from a value by walking through view-like operations until an a...
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
llvm::StringLiteral getVarNamePlaceholder()
Returns a placeholder string for use as an acc.var_name attribute value when the actual variable name...
static constexpr StringLiteral getVarNameAttrName()
std::pair< llvm::SmallVector< Value >, Block::iterator > cloneACCRegionInto(Region *src, Block *dest, Block::iterator inlinePoint, IRMapping &mapping, ValueRange resultsToReplace)
Clone an ACC region into a destination block at the given insertion point.
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region ®ion)
Replace all uses of orig within the given region with replacement.
llvm::TypeSwitch< T, ResultT > TypeSwitch