37 int64_t ratio = control.
ratio;
38 unsigned insertSplitIndex = control.
index;
39 unsigned insertSplitDimension = control.
index;
44 op.getReductionDims(dims);
48 unsigned reductionDim = dims[0];
50 insertSplitDimension = reductionDim + 1;
53 int64_t reductionDimSize = loopRanges[reductionDim];
54 if (reductionDimSize == ShapedType::kDynamic || reductionDimSize % ratio != 0)
56 op,
"Reduction dimension not divisible by split ratio");
57 if (op.getNumDpsInits() != 1)
59 if (insertSplitIndex > op.getShape(op.getDpsInitOperand(0)).size())
61 "compared to intermediate tensor size");
65 combinerOps.size() != 1)
70 if (!identity.has_value())
77 for (
OpOperand *operand : op.getDpsInputOperands()) {
78 AffineMap map = op.getMatchingIndexingMap(operand);
83 for (
unsigned idx : llvm::seq<unsigned>(0, map.
getNumResults())) {
85 if (reductionDim == dim) {
87 newShape.push_back(op.getShape(operand)[idx] / ratio);
88 newShape.push_back(ratio);
93 newShape.push_back(ratio);
94 newShape.push_back(op.getShape(operand)[idx] / ratio);
99 reassociation.push_back({index++, index++});
102 newShape.push_back(op.getShape(operand)[idx]);
105 reassociation.push_back({index++});
110 if (newShape == op.getShape(operand)) {
111 newInputs.push_back(operand->get());
116 cast<RankedTensorType>(operand->get().getType()).getElementType());
118 loc, newType, operand->get(), reassociation);
119 newInputs.push_back(newInput);
125 AffineMap oldOutputMap = op.getMatchingIndexingMap(op.getDpsInitOperand(0));
128 for (
unsigned idx : llvm::seq<unsigned>(0, oldShape.size() + 1)) {
129 if (insertSplitIndex == idx) {
130 newOutputShape.push_back(ratio);
133 if (idx < oldShape.size()) {
134 newOutputShape.push_back(oldShape[idx]);
136 outputExpr.push_back(
140 Value emptyOrAllocTensor;
142 emptyOrAllocTensor = b.
create<bufferization::AllocTensorOp>(
145 op.getRegionOutputArgs()[0].getType()),
148 emptyOrAllocTensor = b.
create<tensor::EmptyOp>(
149 loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
151 Value constantOp = b.
create<arith::ConstantOp>(loc, *identity);
152 Value identityTensor =
153 b.
create<linalg::FillOp>(op->
getLoc(), constantOp, emptyOrAllocTensor)
159 for (
auto [index, iteratorType] :
161 if (insertSplitDimension == index)
162 newIteratorTypes.push_back(utils::IteratorType::parallel);
163 newIteratorTypes.push_back(iteratorType);
165 if (insertSplitDimension == op.getIteratorTypesArray().size()) {
166 newIteratorTypes.push_back(utils::IteratorType::parallel);
170 GenericOp genericOp = b.
create<GenericOp>(
172 ValueRange({identityTensor}), newMaps, newIteratorTypes);
174 genericOp.getRegion().
begin());
178 unsigned intermRank = newOutputShape.size();
182 for (
unsigned i : llvm::seq<unsigned>(0, intermRank)) {
183 if (insertSplitIndex == i) {
184 reductionIteratorTypes.push_back(utils::IteratorType::reduction);
187 reductionIteratorTypes.push_back(utils::IteratorType::parallel);
193 auto reduction = b.
create<GenericOp>(
195 op.getDpsInits(), reductionMaps, reductionIteratorTypes,
205 identityTensor.getDefiningOp<FillOp>(),
206 cast<LinalgOp>(genericOp.getOperation()),
215 unsigned reductionDimPos,
216 int64_t reductionRatio) {
219 AffineMap map = op.getMatchingIndexingMap(&opOperand);
224 reductionDim, reductionDim * reductionRatio + reductionDimP1,
226 return map.
compose(composeMap);
230 unsigned reductionDimPos, int64_t size) {
232 AffineMap map = op.getMatchingIndexingMap(&opOperand);
251 int64_t splitFactor = control.
ratio;
252 unsigned insertSplitDimension = control.
index;
253 if (splitFactor <= 1)
257 op.getReductionDims(dims);
261 unsigned reductionDimPos = dims[0];
263 int64_t reductionDimSize = loopRanges[reductionDimPos];
264 if (reductionDimSize == ShapedType::kDynamic ||
265 reductionDimSize % splitFactor != 0 ||
266 insertSplitDimension >= loopRanges.size())
268 op,
"first reduction dimension not divisible by split factor");
275 for (
Operation *reductionOp : combinerOps) {
276 std::optional<TypedAttr> neutralElement =
278 if (!neutralElement.has_value())
280 neutralElements.push_back(*neutralElement);
282 if (!llvm::all_of(neutralElements, [](
Attribute attr) {
return attr; }))
286 if (op.getNumDpsInits() !=
static_cast<int64_t
>(neutralElements.size()))
306 newOutputs.reserve(op.getNumDpsInits());
309 fillOps.reserve(op.getNumDpsInits());
310 for (
auto it : llvm::zip(op.getDpsInitsMutable(), neutralElements)) {
311 Value rankedTensor = std::get<0>(it).get();
312 auto t = cast<RankedTensorType>(rankedTensor.
getType());
314 reductionDimSize / splitFactor, insertSplitDimension);
317 Value emptyOrAllocTensor;
320 b.
create<bufferization::AllocTensorOp>(loc, newT, dims);
322 emptyOrAllocTensor = b.
create<tensor::EmptyOp>(loc, newT.getShape(),
323 t.getElementType(), dims);
325 Value constantOp = b.
create<arith::ConstantOp>(loc, std::get<1>(it));
327 b.
create<linalg::FillOp>(op->
getLoc(), constantOp, emptyOrAllocTensor));
328 newOutputs.push_back(fillOps.back().getResult(0));
329 emptyOrAllocTensorOps.push_back(emptyOrAllocTensor.
getDefiningOp());
336 for (
OpOperand *o : op.getDpsInputOperands())
339 auto nDims = op.getNumLoops() + 1;
342 newMaps.push_back(
AffineMap::get(nDims, 0, {redDim, redDimP1}, context));
347 for (
OpOperand &o : op.getDpsInitsMutable())
349 reductionDimSize / splitFactor));
356 newInputs.push_back(b.
create<tensor::EmptyOp>(
363 auto iteratorTypes = op.getIteratorTypesArray();
364 iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
365 utils::IteratorType::parallel);
366 GenericOp genericOp =
368 newOutputs, newMaps, iteratorTypes);
370 genericOp.getRegion().
begin());
371 genericOp.getRegion().front().insertArgument(reductionDimPos,
383 llvm::zip(genericOp->getResults(), op.getDpsInits(), combinerOps)) {
384 Value reindexedOutput = std::get<0>(it);
385 Value originalOutput = std::get<1>(it);
386 auto originalOutputType = cast<RankedTensorType>(originalOutput.
getType());
393 originalOutputType.getRank() + 1, utils::IteratorType::parallel);
394 reductionIteratorTypes[insertSplitDimension] =
395 utils::IteratorType::reduction;
398 auto reductionOp = b.
create<GenericOp>(
404 reductionIteratorTypes,
409 b.create<linalg::YieldOp>(loc, clonedReductionOp->
getResult(0));
413 results.push_back(reductionOp);
417 assert(fillOps.size() == results.size() && results.size() == 1);
418 b.
replaceOp(op, results.front()->getResults());
420 cast<LinalgOp>(genericOp.getOperation()),
432 controlSplitReductionFn(std::move(controlSplitReductionFn)),
433 useAlloc(useAlloc) {}
437 return splitReduction(rewriter, op, controlSplitReductionFn, useAlloc);
451 controlSplitReductionFn, useAlloc);
static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand, unsigned reductionDimPos, int64_t reductionRatio)
Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...) TODO: Additional pattern to rewrite f(i,...
static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand, unsigned reductionDimPos, int64_t size)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ...
AffineMap insertResult(AffineExpr expr, unsigned pos) const
Returns a new AffineMap with the same number of dims and symbols and an extra result inserted at pos.
AffineMap dropResult(int64_t pos) const
Returns a new AffineMap with the same number of dims and symbols and one less result at pos,...
unsigned getNumDims() const
unsigned getNumResults() const
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Attributes are known-constant values of operations.
AffineMap getMultiDimIdentityMap(unsigned rank)
IntegerType getIntegerType(unsigned width)
AffineExpr getAffineDimExpr(unsigned position)
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
void setOperand(unsigned idx, Value value)
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
result_type_range getResultTypes()
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This is a builder type that keeps local references to arguments.
Builder & insertDim(int64_t val, unsigned pos)
Insert a val into shape @pos.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateSplitReductionPattern(RewritePatternSet &patterns, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Patterns to apply splitReduction below.
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
SmallVector< Value > createDynamicDimValues(OpBuilder &b, Location loc, Value rankedTensor)
Include the generated interface declarations.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
This class represents an efficient way to signal success or failure.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Apply transformation to split the single linalg op reduction into a parallel and reduction dimension.