24 int32_t defaultTileSize,
29 if (constVal && *constVal < 0) {
31 return mlir::arith::ConstantOp::create(
32 rewriter, loc, targetType,
41 if (loop.hasVector() || loop.getVectorValue()) {
42 loop.removeVectorAttr();
43 loop.removeVectorOperandsDeviceTypeAttr();
44 }
else if (loop.hasWorker() || loop.getWorkerValue()) {
45 loop.removeWorkerAttr();
46 loop.removeWorkerNumOperandsDeviceTypeAttr();
51static mlir::acc::LoopOp
56 mlir::acc::CombinedConstructsTypeAttr combinedAttr,
58 mlir::ArrayAttr collapseAttr = mlir::ArrayAttr{};
59 mlir::ArrayAttr collapseDeviceTypeAttr = mlir::ArrayAttr{};
60 if (preserveCollapse) {
61 collapseAttr = origLoop.getCollapseAttr();
62 collapseDeviceTypeAttr = origLoop.getCollapseDeviceTypeAttr();
64 auto newLoop = mlir::acc::LoopOp::create(
65 rewriter, loc, origLoop->getResultTypes(), lb, ub, step, inclusiveUBAttr,
66 collapseAttr, collapseDeviceTypeAttr, origLoop.getGangOperands(),
67 origLoop.getGangOperandsArgTypeAttr(),
68 origLoop.getGangOperandsSegmentsAttr(),
69 origLoop.getGangOperandsDeviceTypeAttr(), origLoop.getWorkerNumOperands(),
70 origLoop.getWorkerNumOperandsDeviceTypeAttr(),
71 origLoop.getVectorOperands(), origLoop.getVectorOperandsDeviceTypeAttr(),
72 origLoop.getSeqAttr(), origLoop.getIndependentAttr(),
73 origLoop.getAuto_Attr(), origLoop.getGangAttr(), origLoop.getWorkerAttr(),
75 mlir::ArrayAttr{}, origLoop.getCacheOperands(),
76 origLoop.getPrivateOperands(), origLoop.getFirstprivateOperands(),
77 origLoop.getReductionOperands(), combinedAttr);
82static mlir::acc::LoopOp
87 inputLoop, rewriter, lb, ub, step, inclusiveUBAttr,
88 mlir::acc::CombinedConstructsTypeAttr{}, loc,
false);
92 if (inputLoop.hasGang() ||
93 inputLoop.getGangValue(mlir::acc::GangArgType::Num) ||
94 inputLoop.getGangValue(mlir::acc::GangArgType::Dim) ||
95 inputLoop.getGangValue(mlir::acc::GangArgType::Static)) {
96 elementLoop.removeGangAttr();
97 elementLoop.removeGangOperandsArgTypeAttr();
98 elementLoop.removeGangOperandsSegmentsAttr();
99 elementLoop.removeGangOperandsDeviceTypeAttr();
101 if (inputLoop.hasVector() || inputLoop.getVectorValue()) {
102 elementLoop.removeWorkerAttr();
103 elementLoop.removeWorkerNumOperandsDeviceTypeAttr();
109 elementLoop.getRegion().begin());
111 mlir::acc::YieldOp::create(rewriter, loc);
112 elementLoop.getBody().addArgument(
113 inputLoop.getBody().getArgument(0).getType(), loc);
120 mlir::acc::LoopOp targetLoop,
136 movedOps.push_back(op);
137 rewriter.startOpModification(op);
140 targetLoop.getBody().getOperations().splice(
141 targetLoop.getBody().getOperations().begin(),
142 sourceLoop.getBody().getOperations(), begin, end);
145 for (
auto [i, newIV] : llvm::enumerate(newIVs))
157 mlir::acc::LoopOp outerLoop = tileLoops[0];
160 mlir::acc::LoopOp innerLoop = tileLoops[tileLoops.size() - 1];
167 size_t nOps = innerLoop.getBody().getOperations().size();
171 for (
auto tileLoop : tileLoops) {
172 for (
auto [
j, step] : llvm::enumerate(tileLoop.getStep())) {
176 if (tileLoop.getInclusiveUpperboundAttr())
177 inclusiveUBs.push_back(
178 tileLoop.getInclusiveUpperboundAttr().asArrayRef()[
j]);
180 inclusiveUBs.push_back(
false);
186 for (
auto [i, tileLoop] : llvm::enumerate(tileLoops)) {
187 for (
auto arg : tileLoop.getBody().getArguments())
188 origIVs.push_back(arg);
189 for (
auto ub : tileLoop.getUpperbound())
190 origUBs.push_back(
ub);
193 for (
auto [
j, step] : llvm::enumerate(tileLoop.getStep())) {
194 origSteps.push_back(step);
195 if (i +
j >= tileSizes.size()) {
196 currentLoopSteps.push_back(step);
199 tileSizes[i +
j], defaultTileSize, step.getType(), rewriter, loc);
201 mlir::arith::MulIOp::create(rewriter, loc, step, tileSize);
202 currentLoopSteps.push_back(newLoopStep);
203 newSteps.push_back(newLoopStep);
208 tileLoop.getStepMutable().clear();
209 tileLoop.getStepMutable().append(currentLoopSteps);
214 for (
size_t i = 0; i < newSteps.size(); i++) {
218 mlir::arith::AddIOp::create(rewriter, loc, origIVs[i], newSteps[i]);
220 if (inclusiveUBs[i]) {
223 auto c1 = mlir::arith::ConstantOp::create(
224 rewriter, loc, newSteps[i].
getType(),
226 newUB = mlir::arith::SubIOp::create(rewriter, loc, stepped, c1);
229 mlir::arith::MinSIOp::create(rewriter, loc, origUBs[i], newUB));
233 mlir::acc::LoopOp currentLoop = innerLoop;
234 for (
size_t i = 0; i < tileSizes.size(); i++) {
240 mlir::acc::LoopOp elementLoop =
252 newIVs.push_back(elementLoop.getBody().getArgument(0));
253 currentLoop = elementLoop;
257 for (
auto tileLoop : tileLoops) {
271 unsigned collapseCount,
278 for (
unsigned i = 0; i < collapseCount; i++) {
280 bool inclusiveUB =
false;
281 if (origLoop.getInclusiveUpperboundAttr())
282 inclusiveUB = origLoop.getInclusiveUpperboundAttr().asArrayRef()[i];
283 newInclusiveUBs.push_back(inclusiveUB);
284 lbs.push_back(origLoop.getLowerbound()[i]);
285 ubs.push_back(origLoop.getUpperbound()[i]);
286 steps.push_back(origLoop.getStep()[i]);
289 origLoop, rewriter, lbs, ubs, steps,
291 origLoop.getCombinedAttr(), loc,
true);
293 outerLoop.getRegion().begin());
295 mlir::acc::YieldOp::create(rewriter, loc);
296 for (
unsigned i = 0; i < collapseCount; i++) {
297 outerLoop.getBody().addArgument(origLoop.getBody().getArgument(i).getType(),
299 newIVs.push_back(outerLoop.getBody().getArgument(i));
301 newLoops.push_back(outerLoop);
303 mlir::acc::LoopOp currentLoopOp = outerLoop;
304 for (
unsigned i = collapseCount; i < tileCount; i++) {
306 bool inclusiveUB =
false;
307 if (origLoop.getInclusiveUpperboundAttr())
308 inclusiveUB = origLoop.getInclusiveUpperboundAttr().asArrayRef()[i];
315 newIVs.push_back(innerLoop.getBody().getArgument(0));
316 newLoops.push_back(innerLoop);
317 currentLoopOp = innerLoop;
320 size_t nOps = origLoop.getBody().getOperations().size();
322 for (
auto arg : origLoop.getBody().getArguments())
323 origIVs.push_back(arg);
static void removeWorkerVectorFromLoop(mlir::acc::LoopOp loop)
static mlir::acc::LoopOp createACCLoopFromOriginal(mlir::acc::LoopOp origLoop, mlir::RewriterBase &rewriter, mlir::ValueRange lb, mlir::ValueRange ub, mlir::ValueRange step, mlir::DenseBoolArrayAttr inclusiveUBAttr, mlir::acc::CombinedConstructsTypeAttr combinedAttr, mlir::Location loc, bool preserveCollapse)
static void moveOpsAndReplaceIVs(mlir::acc::LoopOp sourceLoop, mlir::acc::LoopOp targetLoop, llvm::ArrayRef< mlir::Value > newIVs, llvm::ArrayRef< mlir::Value > origIVs, size_t nOps, mlir::RewriterBase &rewriter)
static mlir::Value resolveAndCastTileSize(mlir::Value tileSize, int32_t defaultTileSize, mlir::Type targetType, mlir::RewriterBase &rewriter, mlir::Location loc)
static mlir::acc::LoopOp createInnerLoop(mlir::acc::LoopOp inputLoop, mlir::RewriterBase &rewriter, mlir::ValueRange lb, mlir::ValueRange ub, mlir::ValueRange step, mlir::DenseBoolArrayAttr inclusiveUBAttr, mlir::Location loc)
Block represents an ordered list of Operations.
OpListType::iterator iterator
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseBoolArrayAttr getDenseBoolArrayAttr(ArrayRef< bool > values)
Tensor-typed DenseArrayAttr getters.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation is the basic unit of execution within MLIR.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
mlir::acc::LoopOp tileACCLoops(llvm::SmallVector< mlir::acc::LoopOp > &tileLoops, const llvm::SmallVector< mlir::Value > &tileSizes, int32_t defaultTileSize, mlir::RewriterBase &rewriter)
Tile ACC loops according to the given tile sizes.
llvm::SmallVector< mlir::acc::LoopOp > uncollapseLoops(mlir::acc::LoopOp origLoop, unsigned tileCount, unsigned collapseCount, mlir::RewriterBase &rewriter)
Uncollapse tile loops with multiple IVs and collapseCount < tileCount.
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.