24 int32_t defaultTileSize,
29 if (constVal && *constVal < 0) {
31 return mlir::arith::ConstantOp::create(
32 rewriter, loc, targetType,
41 if (loop.hasVector() || loop.getVectorValue()) {
42 loop.removeVectorAttr();
43 loop.removeVectorOperandsDeviceTypeAttr();
44 }
else if (loop.hasWorker() || loop.getWorkerValue()) {
45 loop.removeWorkerAttr();
46 loop.removeWorkerNumOperandsDeviceTypeAttr();
51static mlir::acc::LoopOp
56 mlir::acc::CombinedConstructsTypeAttr combinedAttr,
58 mlir::ArrayAttr collapseAttr = mlir::ArrayAttr{};
59 mlir::ArrayAttr collapseDeviceTypeAttr = mlir::ArrayAttr{};
60 if (preserveCollapse) {
61 collapseAttr = origLoop.getCollapseAttr();
62 collapseDeviceTypeAttr = origLoop.getCollapseDeviceTypeAttr();
64 auto newLoop = mlir::acc::LoopOp::create(
65 rewriter, loc, origLoop->getResultTypes(), lb, ub, step, inclusiveUBAttr,
66 collapseAttr, collapseDeviceTypeAttr, origLoop.getGangOperands(),
67 origLoop.getGangOperandsArgTypeAttr(),
68 origLoop.getGangOperandsSegmentsAttr(),
69 origLoop.getGangOperandsDeviceTypeAttr(), origLoop.getWorkerNumOperands(),
70 origLoop.getWorkerNumOperandsDeviceTypeAttr(),
71 origLoop.getVectorOperands(), origLoop.getVectorOperandsDeviceTypeAttr(),
72 origLoop.getSeqAttr(), origLoop.getIndependentAttr(),
73 origLoop.getAuto_Attr(), origLoop.getGangAttr(), origLoop.getWorkerAttr(),
75 mlir::ArrayAttr{}, origLoop.getCacheOperands(),
76 origLoop.getPrivateOperands(), origLoop.getFirstprivateOperands(),
77 origLoop.getReductionOperands(), combinedAttr);
82static mlir::acc::LoopOp
87 inputLoop, rewriter, lb, ub, step, inclusiveUBAttr,
88 mlir::acc::CombinedConstructsTypeAttr{}, loc,
false);
92 if (inputLoop.hasGang() ||
93 inputLoop.getGangValue(mlir::acc::GangArgType::Num) ||
94 inputLoop.getGangValue(mlir::acc::GangArgType::Dim) ||
95 inputLoop.getGangValue(mlir::acc::GangArgType::Static)) {
96 elementLoop.removeGangAttr();
97 elementLoop.removeGangOperandsArgTypeAttr();
98 elementLoop.removeGangOperandsSegmentsAttr();
99 elementLoop.removeGangOperandsDeviceTypeAttr();
101 if (inputLoop.hasVector() || inputLoop.getVectorValue()) {
102 elementLoop.removeWorkerAttr();
103 elementLoop.removeWorkerNumOperandsDeviceTypeAttr();
109 elementLoop.getRegion().begin());
111 mlir::acc::YieldOp::create(rewriter, loc);
112 elementLoop.getBody().addArgument(
113 inputLoop.getBody().getArgument(0).getType(), loc);
120 mlir::acc::LoopOp targetLoop,
126 targetLoop.getBody().getOperations().splice(
127 targetLoop.getBody().getOperations().begin(),
128 sourceLoop.getBody().getOperations(), begin, std::next(begin, nOps - 1));
131 for (
auto [i, newIV] : llvm::enumerate(newIVs))
140 mlir::acc::LoopOp outerLoop = tileLoops[0];
143 mlir::acc::LoopOp innerLoop = tileLoops[tileLoops.size() - 1];
150 size_t nOps = innerLoop.getBody().getOperations().size();
154 for (
auto tileLoop : tileLoops) {
155 for (
auto [
j, step] : llvm::enumerate(tileLoop.getStep())) {
159 if (tileLoop.getInclusiveUpperboundAttr())
160 inclusiveUBs.push_back(
161 tileLoop.getInclusiveUpperboundAttr().asArrayRef()[
j]);
163 inclusiveUBs.push_back(
false);
169 for (
auto [i, tileLoop] : llvm::enumerate(tileLoops)) {
170 for (
auto arg : tileLoop.getBody().getArguments())
171 origIVs.push_back(arg);
172 for (
auto ub : tileLoop.getUpperbound())
173 origUBs.push_back(
ub);
176 for (
auto [
j, step] : llvm::enumerate(tileLoop.getStep())) {
177 origSteps.push_back(step);
178 if (i +
j >= tileSizes.size()) {
179 currentLoopSteps.push_back(step);
182 tileSizes[i +
j], defaultTileSize, step.getType(), rewriter, loc);
184 mlir::arith::MulIOp::create(rewriter, loc, step, tileSize);
185 currentLoopSteps.push_back(newLoopStep);
186 newSteps.push_back(newLoopStep);
191 tileLoop.getStepMutable().clear();
192 tileLoop.getStepMutable().append(currentLoopSteps);
197 for (
size_t i = 0; i < newSteps.size(); i++) {
201 mlir::arith::AddIOp::create(rewriter, loc, origIVs[i], newSteps[i]);
203 if (inclusiveUBs[i]) {
206 auto c1 = mlir::arith::ConstantOp::create(
207 rewriter, loc, newSteps[i].
getType(),
209 newUB = mlir::arith::SubIOp::create(rewriter, loc, stepped, c1);
212 mlir::arith::MinSIOp::create(rewriter, loc, origUBs[i], newUB));
216 mlir::acc::LoopOp currentLoop = innerLoop;
217 for (
size_t i = 0; i < tileSizes.size(); i++) {
223 mlir::acc::LoopOp elementLoop =
235 newIVs.push_back(elementLoop.getBody().getArgument(0));
236 currentLoop = elementLoop;
240 for (
auto tileLoop : tileLoops) {
254 unsigned collapseCount,
261 for (
unsigned i = 0; i < collapseCount; i++) {
263 bool inclusiveUB =
false;
264 if (origLoop.getInclusiveUpperboundAttr())
265 inclusiveUB = origLoop.getInclusiveUpperboundAttr().asArrayRef()[i];
266 newInclusiveUBs.push_back(inclusiveUB);
267 lbs.push_back(origLoop.getLowerbound()[i]);
268 ubs.push_back(origLoop.getUpperbound()[i]);
269 steps.push_back(origLoop.getStep()[i]);
272 origLoop, rewriter, lbs, ubs, steps,
274 origLoop.getCombinedAttr(), loc,
true);
276 outerLoop.getRegion().begin());
278 mlir::acc::YieldOp::create(rewriter, loc);
279 for (
unsigned i = 0; i < collapseCount; i++) {
280 outerLoop.getBody().addArgument(origLoop.getBody().getArgument(i).getType(),
282 newIVs.push_back(outerLoop.getBody().getArgument(i));
284 newLoops.push_back(outerLoop);
286 mlir::acc::LoopOp currentLoopOp = outerLoop;
287 for (
unsigned i = collapseCount; i < tileCount; i++) {
289 bool inclusiveUB =
false;
290 if (origLoop.getInclusiveUpperboundAttr())
291 inclusiveUB = origLoop.getInclusiveUpperboundAttr().asArrayRef()[i];
298 newIVs.push_back(innerLoop.getBody().getArgument(0));
299 newLoops.push_back(innerLoop);
300 currentLoopOp = innerLoop;
303 size_t nOps = origLoop.getBody().getOperations().size();
305 for (
auto arg : origLoop.getBody().getArguments())
306 origIVs.push_back(arg);
static void removeWorkerVectorFromLoop(mlir::acc::LoopOp loop)
static mlir::acc::LoopOp createACCLoopFromOriginal(mlir::acc::LoopOp origLoop, mlir::RewriterBase &rewriter, mlir::ValueRange lb, mlir::ValueRange ub, mlir::ValueRange step, mlir::DenseBoolArrayAttr inclusiveUBAttr, mlir::acc::CombinedConstructsTypeAttr combinedAttr, mlir::Location loc, bool preserveCollapse)
static void moveOpsAndReplaceIVs(mlir::acc::LoopOp sourceLoop, mlir::acc::LoopOp targetLoop, llvm::ArrayRef< mlir::Value > newIVs, llvm::ArrayRef< mlir::Value > origIVs, size_t nOps, mlir::RewriterBase &rewriter)
static mlir::Value resolveAndCastTileSize(mlir::Value tileSize, int32_t defaultTileSize, mlir::Type targetType, mlir::RewriterBase &rewriter, mlir::Location loc)
static mlir::acc::LoopOp createInnerLoop(mlir::acc::LoopOp inputLoop, mlir::RewriterBase &rewriter, mlir::ValueRange lb, mlir::ValueRange ub, mlir::ValueRange step, mlir::DenseBoolArrayAttr inclusiveUBAttr, mlir::Location loc)
Block represents an ordered list of Operations.
OpListType::iterator iterator
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseBoolArrayAttr getDenseBoolArrayAttr(ArrayRef< bool > values)
Tensor-typed DenseArrayAttr getters.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
mlir::acc::LoopOp tileACCLoops(llvm::SmallVector< mlir::acc::LoopOp > &tileLoops, const llvm::SmallVector< mlir::Value > &tileSizes, int32_t defaultTileSize, mlir::RewriterBase &rewriter)
Tile ACC loops according to the given tile sizes.
llvm::SmallVector< mlir::acc::LoopOp > uncollapseLoops(mlir::acc::LoopOp origLoop, unsigned tileCount, unsigned collapseCount, mlir::RewriterBase &rewriter)
Uncollapse tile loops with multiple IVs and collapseCount < tileCount.
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region ®ion)
Replace all uses of orig within the given region with replacement.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.