FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
    RewriterBase &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);
  Location loc = op->getLoc();

  SplitReductionOptions control = controlSplitReductionFn(op);
  int64_t ratio = control.ratio;
  unsigned insertSplitIndex = control.index;
  unsigned insertSplitDimension = control.index;
  if (ratio <= 1)
    return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  if (dims.size() != 1)
    return b.notifyMatchFailure(op, "needs a single reduction dimension");
  unsigned reductionDim = dims[0];
  if (control.innerParallel)
    insertSplitDimension = reductionDim + 1;
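  // For intuition (an illustrative sketch, not IR emitted verbatim): with
  // ratio = 4, a 1-D sum reduction
  //   %r = sum(%in : tensor<32xf32>) -> tensor<f32>
  // becomes a tensor.expand_shape of %in into tensor<4x8xf32>, a partial
  // reduction over the size-8 dimension producing tensor<4xf32>, and a final
  // reduction of the 4 partial results back into tensor<f32>.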
  SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDim];
  if (reductionDimSize == ShapedType::kDynamic || reductionDimSize % ratio != 0)
    return b.notifyMatchFailure(
        op, "Reduction dimension not divisible by split ratio");
  if (op.getNumDpsInits() != 1)
    return b.notifyMatchFailure(op, "More than one output in split reduction");
  if (insertSplitIndex > op.getShape(op.getDpsInitOperand(0)).size())
    return b.notifyMatchFailure(op, "Insert dimension position too large "
                                    "compared to intermediate tensor size");
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
      combinerOps.size() != 1)
    return b.notifyMatchFailure(op, "Cannot match the reduction pattern");

  Operation *reductionOp = combinerOps[0];
  std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
  if (!identity.has_value())
    return b.notifyMatchFailure(op, "Unknown identity value for the reduction");
  SmallVector<Value> newInputs;
  SmallVector<AffineMap> newMaps;
  // Calculate the new shapes and indexing maps of the input operands.
  for (OpOperand *operand : op.getDpsInputOperands()) {
    AffineMap map = op.getMatchingIndexingMap(operand);
    SmallVector<int64_t> newShape;
    SmallVector<AffineExpr> exprs;
    SmallVector<ReassociationIndices> reassociation;
    unsigned index = 0;
    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
      unsigned dim = map.getDimPosition(idx);
      if (reductionDim == dim) {
        if (control.innerParallel) {
          newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce.
          newShape.push_back(ratio);                             // parallel (insert).
          exprs.push_back(
              b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
          exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
        } else {
          newShape.push_back(ratio);                             // parallel (insert).
          newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce.
          exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
          exprs.push_back(
              b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
        }
        reassociation.push_back({index++, index++});
        continue;
      }
      newShape.push_back(op.getShape(operand)[idx]);
      exprs.push_back(
          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
      reassociation.push_back({index++});
    }
    newMaps.push_back(
        AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
    // If the shape is unchanged the input doesn't need to be reshaped.
    if (newShape == op.getShape(operand)) {
      newInputs.push_back(operand->get());
      continue;
    }
    Type newType = RankedTensorType::get(
        newShape,
        cast<RankedTensorType>(operand->get().getType()).getElementType());
    Value newInput = tensor::ExpandShapeOp::create(
        b, loc, newType, operand->get(), reassociation);
    newInputs.push_back(newInput);
  }
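  // At this point every input touching the reduction dimension has been
  // reshaped: e.g. with ratio = 4, a tensor<32xf32> operand becomes
  // tensor<4x8xf32> via tensor.expand_shape with reassociation [[0, 1]]
  // (or tensor<8x4xf32> when control.innerParallel is set).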
  // Calculate the new output map and shape; the extra dimension is inserted
  // at the position returned by the control function.
  SmallVector<int64_t> newOutputShape;
  AffineMap oldOutputMap = op.getMatchingIndexingMap(op.getDpsInitOperand(0));
  ArrayRef<int64_t> oldShape = op.getShape(op.getDpsInitOperand(0));
  SmallVector<AffineExpr> outputExpr;
  for (unsigned idx : llvm::seq<unsigned>(0, oldShape.size() + 1)) {
    if (insertSplitIndex == idx) {
      newOutputShape.push_back(ratio);
      outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
    }
    if (idx < oldShape.size()) {
      newOutputShape.push_back(oldShape[idx]);
      unsigned dim = oldOutputMap.getDimPosition(idx);
      outputExpr.push_back(
          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
    }
  }
  Value emptyOrAllocTensor;
  if (useAlloc) {
    emptyOrAllocTensor = bufferization::AllocTensorOp::create(
        b, loc,
        RankedTensorType::get(newOutputShape,
                              op.getRegionOutputArgs()[0].getType()),
        ValueRange{});
  } else {
    emptyOrAllocTensor = tensor::EmptyOp::create(
        b, loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
  }
  Value constantOp = arith::ConstantOp::create(b, loc, *identity);
  Value identityTensor =
      linalg::FillOp::create(b, op->getLoc(), constantOp, emptyOrAllocTensor)
          .getResult(0);
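  // Filling with the identity is what makes the split sound: every slot of
  // the intermediate tensor starts at the neutral element, so combining the
  // partial results reproduces exactly the original reduction.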
  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
                                   op.getContext()));
  SmallVector<utils::IteratorType> newIteratorTypes;
  for (auto [index, iteratorType] :
       llvm::enumerate(op.getIteratorTypesArray())) {
    if (insertSplitDimension == index)
      newIteratorTypes.push_back(utils::IteratorType::parallel);
    newIteratorTypes.push_back(iteratorType);
  }
  if (insertSplitDimension == op.getIteratorTypesArray().size())
    newIteratorTypes.push_back(utils::IteratorType::parallel);
  // Create the new op matching the original op with an extra parallel
  // dimension.
  GenericOp genericOp = GenericOp::create(
      b, loc, TypeRange({emptyOrAllocTensor.getType()}), newInputs,
      ValueRange({identityTensor}), newMaps, newIteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(),
                       genericOp.getRegion().begin());
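  // The original region is reused unchanged: the same scalar combiner now
  // accumulates along the extra parallel dimension into the intermediate
  // tensor instead of directly into the final output.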
  // Then create a new reduction op that only reduces the newly added
  // dimension of the previous op.
  unsigned intermRank = newOutputShape.size();
  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
  SmallVector<utils::IteratorType> reductionIteratorTypes;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
    if (insertSplitIndex == i) {
      reductionIteratorTypes.push_back(utils::IteratorType::reduction);
    } else {
      exprs.push_back(b.getAffineDimExpr(i));
      reductionIteratorTypes.push_back(utils::IteratorType::parallel);
    }
  }
  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
  SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};

  auto reduction = GenericOp::create(
      b, loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
      op.getDpsInits(), reductionMaps, reductionIteratorTypes,
      [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
        Operation *clonedReductionOp = b.clone(*reductionOp);
        clonedReductionOp->setOperand(0, inputs[0]);
        clonedReductionOp->setOperand(1, inputs[1]);
        linalg::YieldOp::create(b, loc, clonedReductionOp->getResult(0));
      });
  b.replaceOp(op, reduction.getResults());
  return SplitReductionResult{emptyOrAllocTensor.getDefiningOp(),
                              identityTensor.getDefiningOp<FillOp>(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              reduction};
}
FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
    RewriterBase &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  // Matcher part: enforce preconditions.
  SplitReductionOptions control = controlSplitReductionFn(op);
  if (control.innerParallel)
    return b.notifyMatchFailure(op, "innerParallel not supported");

  int64_t splitFactor = control.ratio;
  unsigned insertSplitDimension = control.index;
  if (splitFactor <= 1)
    return b.notifyMatchFailure(op, "split factor needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  if (dims.empty())
    return b.notifyMatchFailure(op, "needs at least 1 reduction dimension");
  unsigned reductionDimPos = dims[0];
  SmallVector<int64_t> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDimPos];
  if (reductionDimSize == ShapedType::kDynamic ||
      reductionDimSize % splitFactor != 0 ||
      insertSplitDimension >= loopRanges.size())
    return b.notifyMatchFailure(
        op, "first reduction dimension not divisible by split factor");

  SmallVector<Operation *> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps))
    return b.notifyMatchFailure(op, "cannot match a reduction pattern");
  SmallVector<TypedAttr> neutralElements;
  for (Operation *reductionOp : combinerOps) {
    std::optional<TypedAttr> neutralElement =
        arith::getNeutralElement(reductionOp);
    if (!neutralElement.has_value())
      return b.notifyMatchFailure(op, "cannot find neutral element.");
    neutralElements.push_back(*neutralElement);
  }
  if (!llvm::all_of(neutralElements, [](Attribute attr) { return attr; }))
    return b.notifyMatchFailure(op, "unknown reduction neutral");

  // TODO: relax this when multi-reduction support is available.
  if (op.getNumDpsInits() != static_cast<int64_t>(neutralElements.size()))
    return b.notifyMatchFailure(op, "expect one reduction per output");
  // Step 1. Build the intermediate outputs filled with the proper neutral
  // elements. Each such output keeps its original shape with an extra
  // dimension of size reductionDimSize / splitFactor inserted at
  // `insertSplitDimension`.
  Location loc = op->getLoc();
  MLIRContext *context = op.getContext();
  SmallVector<Value> newOutputs;
  newOutputs.reserve(op.getNumDpsInits());
  SmallVector<Operation *> emptyOrAllocTensorOps;
  SmallVector<linalg::FillOp> fillOps;
  fillOps.reserve(op.getNumDpsInits());
  for (auto it : llvm::zip(op.getDpsInitsMutable(), neutralElements)) {
    Value rankedTensor = std::get<0>(it).get();
    auto t = cast<RankedTensorType>(rankedTensor.getType());
    RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
        reductionDimSize / splitFactor, insertSplitDimension);
    SmallVector<Value> dims =
        tensor::createDynamicDimValues(b, loc, rankedTensor);
    Value emptyOrAllocTensor;
    if (useAlloc) {
      emptyOrAllocTensor =
          bufferization::AllocTensorOp::create(b, loc, newT, dims);
    } else {
      emptyOrAllocTensor = tensor::EmptyOp::create(b, loc, newT.getShape(),
                                                   t.getElementType(), dims);
    }
    Value constantOp = arith::ConstantOp::create(b, loc, std::get<1>(it));
    fillOps.push_back(linalg::FillOp::create(b, op->getLoc(), constantOp,
                                             emptyOrAllocTensor));
    newOutputs.push_back(fillOps.back().getResult(0));
    emptyOrAllocTensorOps.push_back(emptyOrAllocTensor.getDefiningOp());
  }
  // Step 2. Reindex / expand indexing maps. Reindex existing input indexings
  // along the reduction dimension: k -> k * splitFactor + k'.
  SmallVector<AffineMap> newMaps;
  newMaps.reserve(op->getNumOperands() + 1);
  for (OpOperand *o : op.getDpsInputOperands())
    newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor));
  // Provision a new indexing for the shape-only tensor.
  auto nDims = op.getNumLoops() + 1;
  auto redDim = getAffineDimExpr(reductionDimPos, context);
  auto redDimP1 = getAffineDimExpr(reductionDimPos + 1, context);
  newMaps.push_back(AffineMap::get(nDims, 0, {redDim, redDimP1}, context));
  // Expand existing output indexings by inserting the new parallel dimension.
  for (OpOperand &o : op.getDpsInitsMutable())
    newMaps.push_back(insertParallelDim(op, o, reductionDimPos,
                                        reductionDimSize / splitFactor));
  // Step 3. Handle operands. Compute the new input tensors; the output
  // tensors were already built in Step 1.
  SmallVector<Value> newInputs = op.getDpsInputs();
  // Add a single shape-only tensor to carry the split dimensions.
  newInputs.push_back(tensor::EmptyOp::create(
      b, loc, ArrayRef<int64_t>{reductionDimSize / splitFactor, splitFactor},
      b.getIntegerType(1)));
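  // The extra i1 tensor only carries the pair (reductionDimSize / splitFactor,
  // splitFactor) in its shape; its values are never read. This avoids more
  // complex inversions of the scaled indexing maps to recover those sizes.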
  // Step 4. Create the new op matching the original op with an extra parallel
  // dimension.
  auto iteratorTypes = op.getIteratorTypesArray();
  iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
                       utils::IteratorType::parallel);
  GenericOp genericOp =
      GenericOp::create(b, loc, ValueRange(newOutputs).getTypes(), newInputs,
                        newOutputs, newMaps, iteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(),
                       genericOp.getRegion().begin());
  genericOp.getRegion().front().insertArgument(reductionDimPos,
                                               b.getIntegerType(1), loc);
  // Step 5. Create new reduction ops that only reduce the newly added
  // dimension from the previous op.
  SmallVector<LinalgOp> results;
  for (auto it :
       llvm::zip(genericOp->getResults(), op.getDpsInits(), combinerOps)) {
    Value reindexedOutput = std::get<0>(it);
    Value originalOutput = std::get<1>(it);
    auto originalOutputType = cast<RankedTensorType>(originalOutput.getType());
    Operation *combinerOp = std::get<2>(it);

    AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1);
    SmallVector<AffineMap> indexingMaps = {
        map, map.dropResult(insertSplitDimension)};
    SmallVector<utils::IteratorType> reductionIteratorTypes(
        originalOutputType.getRank() + 1, utils::IteratorType::parallel);
    reductionIteratorTypes[insertSplitDimension] =
        utils::IteratorType::reduction;

    auto reductionOp = GenericOp::create(
        b, loc, originalOutputType, reindexedOutput, originalOutput,
        indexingMaps, reductionIteratorTypes,
        [combinerOp](OpBuilder &b, Location loc, ValueRange bbArgs) {
          Operation *clonedReductionOp = b.clone(*combinerOp);
          clonedReductionOp->setOperand(0, bbArgs[0]);
          clonedReductionOp->setOperand(1, bbArgs[1]);
          linalg::YieldOp::create(b, loc, clonedReductionOp->getResult(0));
        });

    results.push_back(reductionOp);
  }
  // TODO: extend when multi-reduction support is available.
  assert(fillOps.size() == results.size() && results.size() == 1);
  b.replaceOp(op, results.front()->getResults());

  return SplitReductionResult{emptyOrAllocTensorOps.front(), fillOps.front(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              results.front()};
}