//===-------- SplitReduction.cpp - Split reduction dimension --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg transformation that breaks a reduction
// dimension into a parallel and a reduction dimension.
//
//===----------------------------------------------------------------------===//
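
// As an illustrative sketch of the rewrite (assumed shapes, not from this
// file): for a split ratio R, a reduction such as
//   O(i, j) += I(i, j, k)
// becomes a partial reduction over a reshaped input followed by a reduction
// over the newly introduced parallel dimension kk:
//   O_p(i, j, kk) += I'(i, j, kk, k')   // I' splits k into (kk, k'), kk < R
//   O(i, j)       += O_p(i, j, kk)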

#include <optional>
#include <utility>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::linalg;

FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
    RewriterBase &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  SplitReductionOptions control = controlSplitReductionFn(op);
  int64_t ratio = control.ratio;
  unsigned insertSplitIndex = control.index;
  unsigned insertSplitDimension = control.index;
  if (ratio <= 1)
    return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);

  if (dims.size() != 1)
    return b.notifyMatchFailure(op, "needs a single reduction dimension");
  unsigned reductionDim = dims[0];
  if (control.innerParallel) {
    insertSplitDimension = reductionDim + 1;
  }
  SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDim];
  if (reductionDimSize == ShapedType::kDynamic || reductionDimSize % ratio != 0)
    return b.notifyMatchFailure(
        op, "reduction dimension not divisible by split ratio");
  if (op.getNumDpsInits() != 1)
    return b.notifyMatchFailure(op, "more than one output in split reduction");
  if (insertSplitIndex > op.getShape(op.getDpsInitOperand(0)).size())
    return b.notifyMatchFailure(op, "insert dimension position too large "
                                    "compared to intermediate tensor size");

  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
      combinerOps.size() != 1)
    return b.notifyMatchFailure(op, "cannot match the reduction pattern");

  Operation *reductionOp = combinerOps[0];
  std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
  if (!identity.has_value())
    return b.notifyMatchFailure(op, "unknown identity value for the reduction");
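  // For instance, `arith::getNeutralElement` maps an `arith.addf` combiner to
  // 0.0 and an `arith.mulf` combiner to 1.0, so filling the intermediate
  // tensor with the identity leaves the final result unchanged.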

  Location loc = op->getLoc();
  SmallVector<Value> newInputs;
  SmallVector<AffineMap> newMaps;
  // Calculate the new shapes and indexing maps of the input operands.
  for (OpOperand *operand : op.getDpsInputOperands()) {
    AffineMap map = op.getMatchingIndexingMap(operand);
    SmallVector<int64_t> newShape;
    SmallVector<AffineExpr> exprs;
    SmallVector<ReassociationIndices> reassociation;
    unsigned index = 0;
    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
      unsigned dim = map.getDimPosition(idx);
      if (reductionDim == dim) {
        if (control.innerParallel) {
          newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
          newShape.push_back(ratio);                             // parallel (insert)
          exprs.push_back(
              b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
          exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
        } else {
          newShape.push_back(ratio);                             // parallel (insert)
          newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
          exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
          exprs.push_back(
              b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
        }
        reassociation.push_back({index++, index++});
        continue;
      }
      newShape.push_back(op.getShape(operand)[idx]);
      exprs.push_back(
          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
      reassociation.push_back({index++});
    }
    newMaps.push_back(
        AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
    // If the shape is unchanged the input doesn't change.
    if (newShape == op.getShape(operand)) {
      newInputs.push_back(operand->get());
      continue;
    }
    Type newType = RankedTensorType::get(
        newShape,
        cast<RankedTensorType>(operand->get().getType()).getElementType());

    Value newInput = tensor::ExpandShapeOp::create(
        b, loc, newType, operand->get(), reassociation);
    newInputs.push_back(newInput);
  }

  // Calculate the new output map and shape; the new dimension is inserted
  // based on the index returned by `controlSplitReductionFn`.
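  // For example (hypothetical shapes): with an output of shape 3x5, ratio=4
  // and insertSplitIndex=0, the intermediate output shape becomes 4x3x5.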
  SmallVector<int64_t> newOutputShape;
  AffineMap oldOutputMap = op.getMatchingIndexingMap(op.getDpsInitOperand(0));
  ArrayRef<int64_t> oldShape = op.getShape(op.getDpsInitOperand(0));
  SmallVector<AffineExpr> outputExpr;
  for (unsigned idx : llvm::seq<unsigned>(0, oldShape.size() + 1)) {
    if (insertSplitIndex == idx) {
      newOutputShape.push_back(ratio);
      outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
    }
    if (idx < oldShape.size()) {
      newOutputShape.push_back(oldShape[idx]);
      unsigned dim = oldOutputMap.getDimPosition(idx);
      outputExpr.push_back(
          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
    }
  }
  Value emptyOrAllocTensor;
  if (useAlloc) {
    emptyOrAllocTensor = bufferization::AllocTensorOp::create(
        b, loc,
        RankedTensorType::get(newOutputShape,
                              op.getRegionOutputArgs()[0].getType()),
        ValueRange{});
  } else {
    emptyOrAllocTensor = tensor::EmptyOp::create(
        b, loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
  }
  Value constantOp = arith::ConstantOp::create(b, loc, *identity);
  Value identityTensor =
      linalg::FillOp::create(b, op->getLoc(), constantOp, emptyOrAllocTensor)
          .getResult(0);

  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
                                   op.getContext()));
  SmallVector<utils::IteratorType> newIteratorTypes;
  for (auto [index, iteratorType] :
       llvm::enumerate(op.getIteratorTypesArray())) {
    if (insertSplitDimension == index)
      newIteratorTypes.push_back(utils::IteratorType::parallel);
    newIteratorTypes.push_back(iteratorType);
  }
  if (insertSplitDimension == op.getIteratorTypesArray().size()) {
    newIteratorTypes.push_back(utils::IteratorType::parallel);
  }
  // Create the new op matching the original op with an extra parallel
  // dimension.
  GenericOp genericOp = GenericOp::create(
      b, loc, TypeRange({emptyOrAllocTensor.getType()}), newInputs,
      ValueRange({identityTensor}), newMaps, newIteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(),
                       genericOp.getRegion().begin());

  // Then create a new reduction that only reduces the newly added dimension
  // from the previous op.
  unsigned intermRank = newOutputShape.size();
  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
  SmallVector<utils::IteratorType> reductionIteratorTypes;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
    if (insertSplitIndex == i) {
      reductionIteratorTypes.push_back(utils::IteratorType::reduction);
    } else {
      exprs.push_back(b.getAffineDimExpr(i));
      reductionIteratorTypes.push_back(utils::IteratorType::parallel);
    }
  }
  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
  SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};

  auto reduction = GenericOp::create(
      b, loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
      op.getDpsInits(), reductionMaps, reductionIteratorTypes,
      [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
        Operation *clonedReductionOp = b.clone(*reductionOp);
        clonedReductionOp->setOperand(0, inputs[0]);
        clonedReductionOp->setOperand(1, inputs[1]);
        linalg::YieldOp::create(b, loc, clonedReductionOp->getResult(0));
      });
  b.replaceOp(op, reduction.getResults());

  return SplitReductionResult{emptyOrAllocTensor.getDefiningOp(),
                              identityTensor.getDefiningOp<FillOp>(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              reduction};
}
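
// A minimal usage sketch (illustrative; `rewriter` and `op` are assumed to be
// provided by a surrounding pattern or transform):
//
//   ControlSplitReductionFn control = [](LinalgOp) {
//     return SplitReductionOptions{/*ratio=*/4, /*index=*/0,
//                                  /*innerParallel=*/false};
//   };
//   FailureOr<SplitReductionResult> result =
//       splitReduction(rewriter, op, control);
//   // On success, `result` holds the init/alloc op, the fill, the partial
//   // reduction and the final combining op created above.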

/// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...)
/// TODO: Additional pattern to rewrite f(i, j, k * ratio + kk, ...) into
/// f(i, j, k, kk, ...) with a proper ExpandShapeOp. This is probably better
/// done as a transform to enable better vectorization.
static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos,
                                   int64_t reductionRatio) {
  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
  auto reductionDimP1 = getAffineDimExpr(reductionDimPos + 1, op.getContext());
  AffineMap map = op.getMatchingIndexingMap(&opOperand);
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  AffineMap composeMap = shiftedIdMap.replace(
      reductionDim, reductionDim * reductionRatio + reductionDimP1,
      shiftedIdMap.getNumDims(), /*numSymbols=*/0);
  return map.compose(composeMap);
}
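
// For example, with the identity indexing map (d0, d1, d2) -> (d0, d1, d2),
// reductionDimPos = 2 and reductionRatio = 16, the composed result is
// (d0, d1, d2, d3) -> (d0, d1, d2 * 16 + d3).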

static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos, int64_t size) {
  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
  AffineMap map = op.getMatchingIndexingMap(&opOperand);
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  return map.compose(shiftedIdMap).insertResult(reductionDim, reductionDimPos);
}
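
// For example, with an output indexing map (d0, d1, d2) -> (d0, d1) and
// reductionDimPos = 2, shifting yields (d0, d1, d2, d3) -> (d0, d1) and the
// new parallel dimension d2 is inserted as result 2, giving
// (d0, d1, d2, d3) -> (d0, d1, d2).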

/// Core rewrite implementation.
FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
    RewriterBase &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  // Matcher part, enforce preconditions.
  SplitReductionOptions control = controlSplitReductionFn(op);
  if (control.innerParallel)
    return b.notifyMatchFailure(op, "innerParallel not supported");

  int64_t splitFactor = control.ratio;
  unsigned insertSplitDimension = control.index;
  if (splitFactor <= 1)
    return b.notifyMatchFailure(op, "split factor needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  if (dims.empty())
    return b.notifyMatchFailure(op, "needs at least 1 reduction dimension");

  unsigned reductionDimPos = dims[0];
  SmallVector<int64_t> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDimPos];
  if (reductionDimSize == ShapedType::kDynamic ||
      reductionDimSize % splitFactor != 0 ||
      insertSplitDimension >= loopRanges.size())
    return b.notifyMatchFailure(
        op, "first reduction dimension not divisible by split factor, or "
            "insertion dimension out of range");

  SmallVector<Operation *> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps))
    return b.notifyMatchFailure(op, "cannot match a reduction pattern");

  SmallVector<TypedAttr> neutralElements;
  for (Operation *reductionOp : combinerOps) {
    std::optional<TypedAttr> neutralElement =
        arith::getNeutralElement(reductionOp);
    if (!neutralElement.has_value())
      return b.notifyMatchFailure(op, "cannot find neutral element");
    neutralElements.push_back(*neutralElement);
  }
  if (!llvm::all_of(neutralElements, [](Attribute attr) { return attr; }))
    return b.notifyMatchFailure(op, "unknown reduction neutral");

  // TODO: relax this when multi-reduction support is available.
  if (op.getNumDpsInits() != static_cast<int64_t>(neutralElements.size()))
    return b.notifyMatchFailure(op, "expect one reduction per output");

  // Rewrite part.
  // Step 1. Build the intermediate outputs filled with the proper
  // neutralElements. Such outputs are of the same shape with an extra dimension
  // inserted at `insertSplitDimension`.
  //
  // Consider a minimal example where `k` is reduced:
  //     O(i, j) += I(i, j, k)
  // Assume i=3, j=5, k=128, splitFactor=16 and insertSplitDimension=0.
  // The compute is rewritten as:
  //   a. O_i(kk, i, j) += I(i, j, 16 * k + kk)
  //   b. O(i, j) += O_i(kk, i, j)
  // The intermediate tensor O_i is of shape (128/16)x3x5 == 8x3x5.
  Location loc = op->getLoc();
  MLIRContext *context = op.getContext();
  // For now assume outputs are 1-1 with reduction neutralElements.
  // TODO: generalize when multi-reduction support is available.
  SmallVector<Value> newOutputs;
  newOutputs.reserve(op.getNumDpsInits());
  SmallVector<Operation *> emptyOrAllocTensorOps;
  SmallVector<linalg::FillOp> fillOps;
  fillOps.reserve(op.getNumDpsInits());
  for (auto it : llvm::zip(op.getDpsInitsMutable(), neutralElements)) {
    Value rankedTensor = std::get<0>(it).get();
    auto t = cast<RankedTensorType>(rankedTensor.getType());
    RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
        reductionDimSize / splitFactor, insertSplitDimension);
    SmallVector<Value> dims =
        tensor::createDynamicDimValues(b, loc, rankedTensor);
    Value emptyOrAllocTensor;
    if (useAlloc) {
      emptyOrAllocTensor =
          bufferization::AllocTensorOp::create(b, loc, newT, dims);
    } else {
      emptyOrAllocTensor = tensor::EmptyOp::create(b, loc, newT.getShape(),
                                                   t.getElementType(), dims);
    }
    Value constantOp = arith::ConstantOp::create(b, loc, std::get<1>(it));
    fillOps.push_back(linalg::FillOp::create(b, op->getLoc(), constantOp,
                                             emptyOrAllocTensor));
    newOutputs.push_back(fillOps.back().getResult(0));
    emptyOrAllocTensorOps.push_back(emptyOrAllocTensor.getDefiningOp());
  }

  // Step 2. Reindex / expand indexing maps.
  // Reindex existing input indexings: k -> k * splitFactor + k'.
  SmallVector<AffineMap> newMaps;
  newMaps.reserve(op->getNumOperands() + 1);
  for (OpOperand *o : op.getDpsInputOperands())
    newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor));
  // Provision a new indexing for the shape-only tensor.
  auto nDims = op.getNumLoops() + 1;
  auto redDim = getAffineDimExpr(reductionDimPos, context);
  auto redDimP1 = getAffineDimExpr(reductionDimPos + 1, context);
  newMaps.push_back(AffineMap::get(nDims, 0, {redDim, redDimP1}, context));
  // Expand existing output indexings.
  // TODO: a subset of these may not reduce along reducePos and should be
  // reindexed: k -> k * splitFactor + k', when multi-reduction support is
  // available.
  for (OpOperand &o : op.getDpsInitsMutable())
    newMaps.push_back(insertParallelDim(op, o, reductionDimPos,
                                        reductionDimSize / splitFactor));

  // Step 3. Handle operands.
  // Compute the new input tensors.
  SmallVector<Value> newInputs = op.getDpsInputs();
  // Add a single shape-only tensor to carry the dimensions without resorting
  // to more complex inversions.
  newInputs.push_back(tensor::EmptyOp::create(
      b, loc, ArrayRef<int64_t>{reductionDimSize / splitFactor, splitFactor},
      b.getIntegerType(1)));
  // Output tensors are already good to go.

  // Step 4. Create the new op matching the original op with an extra parallel
  // dimension.
  auto iteratorTypes = op.getIteratorTypesArray();
  iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
                       utils::IteratorType::parallel);
  GenericOp genericOp =
      GenericOp::create(b, loc, ValueRange(newOutputs).getTypes(), newInputs,
                        newOutputs, newMaps, iteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(),
                       genericOp.getRegion().begin());
  genericOp.getRegion().front().insertArgument(reductionDimPos,
                                               b.getIntegerType(1), loc);

  // Step 5. Create new reduction ops that only reduce the newly added
  // dimensions from the previous op.
  // For now assume outputs are 1-1 with reduction ops.
  // TODO: a subset of these may not reduce in the first place and do not
  // require a new op, when multi-reduction support is available.
  // TODO: all results can be handled in a single GenericOp, when
  // multi-reduction support is available.
  SmallVector<LinalgOp> results;
  for (auto it :
       llvm::zip(genericOp->getResults(), op.getDpsInits(), combinerOps)) {
    Value reindexedOutput = std::get<0>(it);
    Value originalOutput = std::get<1>(it);
    auto originalOutputType = cast<RankedTensorType>(originalOutput.getType());
    Operation *combinerOp = std::get<2>(it);

    AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1);
    SmallVector<AffineMap> indexingMaps = {
        map, map.dropResult(insertSplitDimension)};
    SmallVector<utils::IteratorType> reductionIteratorTypes(
        originalOutputType.getRank() + 1, utils::IteratorType::parallel);
    reductionIteratorTypes[insertSplitDimension] =
        utils::IteratorType::reduction;

    // clang-format off
    auto reductionOp = GenericOp::create(b,
      loc,
      originalOutputType,
      reindexedOutput,
      originalOutput,
      indexingMaps,
      reductionIteratorTypes,
      [combinerOp](OpBuilder &b, Location loc, ValueRange bbArgs) {
        Operation *clonedReductionOp = b.clone(*combinerOp);
        clonedReductionOp->setOperand(0, bbArgs[0]);
        clonedReductionOp->setOperand(1, bbArgs[1]);
        linalg::YieldOp::create(b, loc, clonedReductionOp->getResult(0));
      });
    // clang-format on

    results.push_back(reductionOp);
  }

  // TODO: extend when multi-reduction support is available.
  assert(fillOps.size() == results.size() && results.size() == 1);
  b.replaceOp(op, results.front()->getResults());
  return SplitReductionResult{emptyOrAllocTensorOps.front(), fillOps.front(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              results.front()};
}

namespace {

struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
  /// Construct a generic pattern applied to all LinalgOps.
  LinalgSplitReduction(MLIRContext *context,
                       ControlSplitReductionFn controlSplitReductionFn,
                       bool useAlloc = false, PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlSplitReductionFn(std::move(controlSplitReductionFn)),
        useAlloc(useAlloc) {}

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    return splitReduction(rewriter, op, controlSplitReductionFn, useAlloc);
  }

private:
  ControlSplitReductionFn controlSplitReductionFn;
  bool useAlloc;
};

} // namespace

void linalg::populateSplitReductionPattern(
    RewritePatternSet &patterns,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  patterns.add<LinalgSplitReduction>(patterns.getContext(),
                                     controlSplitReductionFn, useAlloc);
}
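
// Illustrative sketch of wiring this pattern into a pass; the pass
// boilerplate and the `applyPatternsGreedily` driver call are assumptions,
// not part of this file:
//
//   void runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     linalg::populateSplitReductionPattern(
//         patterns, [](LinalgOp op) {
//           return SplitReductionOptions{/*ratio=*/8, /*index=*/0,
//                                        /*innerParallel=*/false};
//         });
//     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//       signalPassFailure();
//   }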