18#include "llvm/ADT/STLExtras.h"
33 if (
auto mod = llvm::dyn_cast<ModuleOp>(current)) {
34 if (mod.getDataLayoutSpec())
36 }
else if (
auto dataLayoutOp =
37 llvm::dyn_cast<DataLayoutOpInterface>(current)) {
39 if (dataLayoutOp.getDataLayoutSpec())
48 if (
auto mod = llvm::dyn_cast<ModuleOp>(op))
67 for (
auto val : output)
68 resultTypes.push_back(val.getType());
70 ComputeRegionOp::create(rewriter, loc, resultTypes, launchArgs, inputArgs,
71 stream, origin, kernelFuncName, kernelModuleName);
73 assert(!regionToClone.
getBlocks().empty() &&
74 "empty region for acc.compute_region");
77 ValueRange mapKeys = inputArgsToMap.empty() ? inputArgs : inputArgsToMap;
78 assert(mapKeys.size() == inputArgs.size() &&
79 "inputArgsToMap must have same size as inputArgs when provided");
83 for (
size_t i = 0; i < launchArgs.size(); ++i)
85 for (
Value input : inputArgs)
87 for (
size_t i = 0; i < inputArgs.size(); ++i)
88 mapping.
map(mapKeys[i], entryBlock->
getArgument(launchArgs.size() + i));
90 if (regionToClone.
getBlocks().size() == 1) {
94 rewriter.
clone(op, mapping);
97 for (
auto val : output)
98 yieldOperands.push_back(mapping.
lookup(val));
100 YieldOp::create(rewriter, loc, yieldOperands);
103 regionToClone, mapping, loc, rewriter);
105 rewriter.
eraseOp(computeRegion);
109 llvm::to_vector(exeRegion.getOps<scf::YieldOp>()));
110 assert(!yieldOps.empty() &&
111 "multi-block region must contain at least one scf.yield");
112 assert(llvm::all_of(yieldOps,
113 [&output](scf::YieldOp yieldOp) {
114 return yieldOp.getNumOperands() ==
115 static_cast<int64_t>(output.size()) &&
117 llvm::zip(yieldOp.getOperands(), output),
119 return std::get<0>(pair).getType() ==
120 std::get<1>(pair).getType();
123 "each scf.yield operand count and types must match output");
125 YieldOp::create(rewriter, loc, exeRegion.getResults());
128 return computeRegion;
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
OpListType & getOperations()
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
The main mechanism for performing data layout queries.
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
std::optional< DataLayout > getDataLayout(Operation *op, bool allowDefault=true)
Get the data layout for an operation.
ComputeRegionOp buildComputeRegion(Location loc, ValueRange launchArgs, ValueRange inputArgs, llvm::StringRef origin, Region ®ionToClone, RewriterBase &rewriter, IRMapping &mapping, ValueRange output={}, FlatSymbolRefAttr kernelFuncName={}, FlatSymbolRefAttr kernelModuleName={}, Value stream={}, ValueRange inputArgsToMap={})
Build an acc.compute_region operation by cloning a source region.
scf::ExecuteRegionOp wrapMultiBlockRegionWithSCFExecuteRegion(Region ®ion, IRMapping &mapping, Location loc, RewriterBase &rewriter)
Wrap a multi-block region in an scf.execute_region.
Include the generated interface declarations.