MLIR  19.0.0git
SplitReduction.cpp
Go to the documentation of this file.
1 //===-------- SplitReduction.cpp - Split reduction dimension --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements linalg transformation to break a reduction dimension
10 // between a parallel and a reduction dimension.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include <optional>
#include <utility>

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"
27 using namespace mlir;
28 using namespace mlir::linalg;
29 
31  RewriterBase &b, LinalgOp op,
32  const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
34  b.setInsertionPoint(op);
35 
36  SplitReductionOptions control = controlSplitReductionFn(op);
37  int64_t ratio = control.ratio;
38  unsigned insertSplitIndex = control.index;
39  unsigned insertSplitDimension = control.index;
40  if (ratio <= 1)
41  return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");
42 
44  op.getReductionDims(dims);
45 
46  if (dims.size() != 1)
47  return b.notifyMatchFailure(op, "needs a single reduction dimension");
48  unsigned reductionDim = dims[0];
49  if (control.innerParallel) {
50  insertSplitDimension = reductionDim + 1;
51  }
52  SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
53  int64_t reductionDimSize = loopRanges[reductionDim];
54  if (reductionDimSize == ShapedType::kDynamic || reductionDimSize % ratio != 0)
55  return b.notifyMatchFailure(
56  op, "Reduction dimension not divisible by split ratio");
57  if (op.getNumDpsInits() != 1)
58  return b.notifyMatchFailure(op, "More than one output in split reduction");
59  if (insertSplitIndex > op.getShape(op.getDpsInitOperand(0)).size())
60  return b.notifyMatchFailure(op, "Insert dimension position too large "
61  "compared to intermediate tensor size");
62 
63  SmallVector<Operation *, 4> combinerOps;
64  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
65  combinerOps.size() != 1)
66  return b.notifyMatchFailure(op, "Cannot match the reduction pattern");
67 
68  Operation *reductionOp = combinerOps[0];
69  std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
70  if (!identity.has_value())
71  return b.notifyMatchFailure(op, "Unknown identity value for the reduction");
72 
73  Location loc = op->getLoc();
74  SmallVector<Value> newInputs;
75  SmallVector<AffineMap> newMaps;
76  // Calculate the new shapes and indexing maps of the input operands.
77  for (OpOperand *operand : op.getDpsInputOperands()) {
78  AffineMap map = op.getMatchingIndexingMap(operand);
79  SmallVector<int64_t> newShape;
82  unsigned index = 0;
83  for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
84  unsigned dim = map.getDimPosition(idx);
85  if (reductionDim == dim) {
86  if (control.innerParallel) {
87  newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
88  newShape.push_back(ratio); // parallel (insert)
89  exprs.push_back(
90  b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
91  exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
92  } else {
93  newShape.push_back(ratio); // parallel (insert)
94  newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
95  exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
96  exprs.push_back(
97  b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
98  }
99  reassociation.push_back({index++, index++});
100  continue;
101  }
102  newShape.push_back(op.getShape(operand)[idx]);
103  exprs.push_back(
104  b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
105  reassociation.push_back({index++});
106  }
107  newMaps.push_back(
108  AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
109  // If the shape is unchanged the input doesn't change.
110  if (newShape == op.getShape(operand)) {
111  newInputs.push_back(operand->get());
112  continue;
113  }
114  Type newType = RankedTensorType::get(
115  newShape,
116  cast<RankedTensorType>(operand->get().getType()).getElementType());
117  Value newInput = b.create<tensor::ExpandShapeOp>(
118  loc, newType, operand->get(), reassociation);
119  newInputs.push_back(newInput);
120  }
121 
122  // Calculate the new output map and shape, we insert the new dimension based
123  // on the index returned by `controlSplitReductionFn`.
124  SmallVector<int64_t> newOutputShape;
125  AffineMap oldOutputMap = op.getMatchingIndexingMap(op.getDpsInitOperand(0));
126  ArrayRef<int64_t> oldShape = op.getShape(op.getDpsInitOperand(0));
127  SmallVector<AffineExpr> outputExpr;
128  for (unsigned idx : llvm::seq<unsigned>(0, oldShape.size() + 1)) {
129  if (insertSplitIndex == idx) {
130  newOutputShape.push_back(ratio);
131  outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
132  }
133  if (idx < oldShape.size()) {
134  newOutputShape.push_back(oldShape[idx]);
135  unsigned dim = oldOutputMap.getDimPosition(idx);
136  outputExpr.push_back(
137  b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
138  }
139  }
140  Value emptyOrAllocTensor;
141  if (useAlloc) {
142  emptyOrAllocTensor = b.create<bufferization::AllocTensorOp>(
143  loc,
144  RankedTensorType::get(newOutputShape,
145  op.getRegionOutputArgs()[0].getType()),
146  ValueRange{});
147  } else {
148  emptyOrAllocTensor = b.create<tensor::EmptyOp>(
149  loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
150  }
151  Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
152  Value identityTensor =
153  b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor)
154  .getResult(0);
155 
156  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
157  op.getContext()));
158  SmallVector<utils::IteratorType> newIteratorTypes;
159  for (auto [index, iteratorType] :
160  llvm::enumerate(op.getIteratorTypesArray())) {
161  if (insertSplitDimension == index)
162  newIteratorTypes.push_back(utils::IteratorType::parallel);
163  newIteratorTypes.push_back(iteratorType);
164  }
165  if (insertSplitDimension == op.getIteratorTypesArray().size()) {
166  newIteratorTypes.push_back(utils::IteratorType::parallel);
167  }
168  // Create the new op matching the original op with an extra parallel
169  // dimension.
170  GenericOp genericOp = b.create<GenericOp>(
171  loc, TypeRange({emptyOrAllocTensor.getType()}), newInputs,
172  ValueRange({identityTensor}), newMaps, newIteratorTypes);
173  b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(),
174  genericOp.getRegion().begin());
175 
176  // Then create a new reduction that only reduce the newly added dimension
177  // from the previous op.
178  unsigned intermRank = newOutputShape.size();
179  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
180  SmallVector<utils::IteratorType> reductionIteratorTypes;
182  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
183  if (insertSplitIndex == i) {
184  reductionIteratorTypes.push_back(utils::IteratorType::reduction);
185  } else {
186  exprs.push_back(b.getAffineDimExpr(i));
187  reductionIteratorTypes.push_back(utils::IteratorType::parallel);
188  }
189  }
190  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
191  SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};
192 
193  auto reduction = b.create<GenericOp>(
194  loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
195  op.getDpsInits(), reductionMaps, reductionIteratorTypes,
196  [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
197  Operation *clonedReductionOp = b.clone(*reductionOp);
198  clonedReductionOp->setOperand(0, inputs[0]);
199  clonedReductionOp->setOperand(1, inputs[1]);
200  b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
201  });
202  b.replaceOp(op, reduction.getResults());
203 
204  return SplitReductionResult{emptyOrAllocTensor.getDefiningOp(),
205  identityTensor.getDefiningOp<FillOp>(),
206  cast<LinalgOp>(genericOp.getOperation()),
207  reduction};
208 }
209 
210 /// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...)
211 /// TODO: Additional pattern to rewrite f(i, j, k * ratio + kk, ...) into
212 /// f(i, j, k, kk, ...) with a proper ExpandShapeOp. This is probably better
213 /// done as a transform to enable better vectorization.
214 static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand,
215  unsigned reductionDimPos,
216  int64_t reductionRatio) {
217  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
218  auto reductionDimP1 = getAffineDimExpr(reductionDimPos + 1, op.getContext());
219  AffineMap map = op.getMatchingIndexingMap(&opOperand);
220  AffineMap idMap =
222  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
223  AffineMap composeMap = shiftedIdMap.replace(
224  reductionDim, reductionDim * reductionRatio + reductionDimP1,
225  shiftedIdMap.getNumDims(), /*numSymbols=*/0);
226  return map.compose(composeMap);
227 }
228 
229 static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand,
230  unsigned reductionDimPos, int64_t size) {
231  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
232  AffineMap map = op.getMatchingIndexingMap(&opOperand);
233  AffineMap idMap =
235  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
236  return map.compose(shiftedIdMap).insertResult(reductionDim, reductionDimPos);
237 }
238 
239 /// Core rewrite implementation.
241  RewriterBase &b, LinalgOp op,
242  const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
243  OpBuilder::InsertionGuard guard(b);
244  b.setInsertionPoint(op);
245 
246  // Matcher part, enforce preconditions.
247  SplitReductionOptions control = controlSplitReductionFn(op);
248  if (control.innerParallel)
249  return b.notifyMatchFailure(op, "innerParallel not supported");
250 
251  int64_t splitFactor = control.ratio;
252  unsigned insertSplitDimension = control.index;
253  if (splitFactor <= 1)
254  return b.notifyMatchFailure(op, "split factor needs to be greater than 1");
255 
257  op.getReductionDims(dims);
258  if (dims.empty())
259  return b.notifyMatchFailure(op, "needs at least 1 reduction dimension");
260 
261  unsigned reductionDimPos = dims[0];
262  SmallVector<int64_t> loopRanges = op.getStaticLoopRanges();
263  int64_t reductionDimSize = loopRanges[reductionDimPos];
264  if (reductionDimSize == ShapedType::kDynamic ||
265  reductionDimSize % splitFactor != 0 ||
266  insertSplitDimension >= loopRanges.size())
267  return b.notifyMatchFailure(
268  op, "first reduction dimension not divisible by split factor");
269 
270  SmallVector<Operation *> combinerOps;
271  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps))
272  return b.notifyMatchFailure(op, "cannot match a reduction pattern");
273 
274  SmallVector<TypedAttr> neutralElements;
275  for (Operation *reductionOp : combinerOps) {
276  std::optional<TypedAttr> neutralElement =
277  arith::getNeutralElement(reductionOp);
278  if (!neutralElement.has_value())
279  return b.notifyMatchFailure(op, "cannot find neutral element.");
280  neutralElements.push_back(*neutralElement);
281  }
282  if (!llvm::all_of(neutralElements, [](Attribute attr) { return attr; }))
283  return b.notifyMatchFailure(op, "unknown reduction neutral");
284 
285  // TODO: relax this when multi-reduction support is available.
286  if (op.getNumDpsInits() != static_cast<int64_t>(neutralElements.size()))
287  return b.notifyMatchFailure(op, "expect one reduction per output");
288 
289  // Rewrite part.
290  // Step 1. Build the intermediate outputs filled with the proper
291  // neutralElements. Such outputs are of the same shape with an extra dimension
292  // inserted at `insertSplitDimension`.
293  //
294  // Consider a minimal example where `k` is reduced:
295  // O(i, j) += I(i, j, k)
296  // Assume i=3, j=5, k=128, splitFactor=16 and insertSplitDimension=0.
297  // The compute is rewritten as:
298  // a. O_i(kk, i, j) += I(i, j, 16 * k + kk)
299  // b. O(i, j) += O_i(kk, i, j)
300  // The intermediate tensor O_i is of shape (128/16)x3x5 == 8x3x5.
301  Location loc = op->getLoc();
302  MLIRContext *context = op.getContext();
303  // For now assume outputs are 1-1 with reduction neutralElements.
304  // TODO: generalize when multi-reduction support is available.
305  SmallVector<Value> newOutputs;
306  newOutputs.reserve(op.getNumDpsInits());
307  SmallVector<Operation *> emptyOrAllocTensorOps;
309  fillOps.reserve(op.getNumDpsInits());
310  for (auto it : llvm::zip(op.getDpsInitsMutable(), neutralElements)) {
311  Value rankedTensor = std::get<0>(it).get();
312  auto t = cast<RankedTensorType>(rankedTensor.getType());
313  RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
314  reductionDimSize / splitFactor, insertSplitDimension);
315  SmallVector<Value> dims =
316  tensor::createDynamicDimValues(b, loc, rankedTensor);
317  Value emptyOrAllocTensor;
318  if (useAlloc) {
319  emptyOrAllocTensor =
320  b.create<bufferization::AllocTensorOp>(loc, newT, dims);
321  } else {
322  emptyOrAllocTensor = b.create<tensor::EmptyOp>(loc, newT.getShape(),
323  t.getElementType(), dims);
324  }
325  Value constantOp = b.create<arith::ConstantOp>(loc, std::get<1>(it));
326  fillOps.push_back(
327  b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor));
328  newOutputs.push_back(fillOps.back().getResult(0));
329  emptyOrAllocTensorOps.push_back(emptyOrAllocTensor.getDefiningOp());
330  }
331 
332  // Step 2. Reindex / expand indexing maps.
333  // Reindex existing input indexings: k -> k * splitFactor + k'.
334  SmallVector<AffineMap> newMaps;
335  newMaps.reserve(op->getNumOperands() + 1);
336  for (OpOperand *o : op.getDpsInputOperands())
337  newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor));
338  // Provision a new indexing for the shape-only tensor.
339  auto nDims = op.getNumLoops() + 1;
340  auto redDim = getAffineDimExpr(reductionDimPos, context);
341  auto redDimP1 = getAffineDimExpr(reductionDimPos + 1, context);
342  newMaps.push_back(AffineMap::get(nDims, 0, {redDim, redDimP1}, context));
343  // Expand existing output indexings.
344  // TODO: a subset of these may not reduce along reducePos and should be
345  // reindexed: k -> k * splitFactor + k', when multi-reduction support is
346  // available.
347  for (OpOperand &o : op.getDpsInitsMutable())
348  newMaps.push_back(insertParallelDim(op, o, reductionDimPos,
349  reductionDimSize / splitFactor));
350 
351  // Step 3. Handle operands.
352  // Compute the new input tensors.
353  SmallVector<Value> newInputs = op.getDpsInputs();
354  // Add a single shape-only tensor to carry the dimensions without resorting to
355  // more complex inversions.
356  newInputs.push_back(b.create<tensor::EmptyOp>(
357  loc, ArrayRef<int64_t>{reductionDimSize / splitFactor, splitFactor},
358  b.getIntegerType(1)));
359  // Output tensors are already good to go.
360 
361  // Step 4. Create the new op matching the original op with an extra parallel
362  // dimension.
363  auto iteratorTypes = op.getIteratorTypesArray();
364  iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
365  utils::IteratorType::parallel);
366  GenericOp genericOp =
367  b.create<GenericOp>(loc, ValueRange(newOutputs).getTypes(), newInputs,
368  newOutputs, newMaps, iteratorTypes);
369  b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(),
370  genericOp.getRegion().begin());
371  genericOp.getRegion().front().insertArgument(reductionDimPos,
372  b.getIntegerType(1), loc);
373 
374  // Step 5. Create new reduction ops that only reduce the newly added
375  // dimensions from the previous op.
376  // For now assume outputs are 1-1 with reduction ops.
377  // TODO: a subset of these may not reduce in the first place and do not
378  // require a new op, when multi-reduction support is available.
379  // TODO: all results can be handled in a single GenericOp, when
380  // multi-reduction support is available.
381  SmallVector<LinalgOp> results;
382  for (auto it :
383  llvm::zip(genericOp->getResults(), op.getDpsInits(), combinerOps)) {
384  Value reindexedOutput = std::get<0>(it);
385  Value originalOutput = std::get<1>(it);
386  auto originalOutputType = cast<RankedTensorType>(originalOutput.getType());
387  Operation *combinerOp = std::get<2>(it);
388 
389  AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1);
390  SmallVector<AffineMap> indexingMaps = {
391  map, map.dropResult(insertSplitDimension)};
392  SmallVector<utils::IteratorType> reductionIteratorTypes(
393  originalOutputType.getRank() + 1, utils::IteratorType::parallel);
394  reductionIteratorTypes[insertSplitDimension] =
395  utils::IteratorType::reduction;
396 
397  // clang-format off
398  auto reductionOp = b.create<GenericOp>(
399  loc,
400  originalOutputType,
401  reindexedOutput,
402  originalOutput,
403  indexingMaps,
404  reductionIteratorTypes,
405  [combinerOp](OpBuilder &b, Location loc, ValueRange bbArgs) {
406  Operation *clonedReductionOp = b.clone(*combinerOp);
407  clonedReductionOp->setOperand(0, bbArgs[0]);
408  clonedReductionOp->setOperand(1, bbArgs[1]);
409  b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
410  });
411  // clang-format on
412 
413  results.push_back(reductionOp);
414  }
415 
416  // TODO: extend when multi-reduction support is available.
417  assert(fillOps.size() == results.size() && results.size() == 1);
418  b.replaceOp(op, results.front()->getResults());
419  return SplitReductionResult{emptyOrAllocTensorOps.front(), fillOps.front(),
420  cast<LinalgOp>(genericOp.getOperation()),
421  results.front()};
422 }
423 
424 namespace {
425 
426 struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
427  /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
428  LinalgSplitReduction(MLIRContext *context,
429  ControlSplitReductionFn controlSplitReductionFn,
430  bool useAlloc = false, PatternBenefit benefit = 1)
431  : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
432  controlSplitReductionFn(std::move(controlSplitReductionFn)),
433  useAlloc(useAlloc) {}
434 
435  LogicalResult matchAndRewrite(LinalgOp op,
436  PatternRewriter &rewriter) const override {
437  return splitReduction(rewriter, op, controlSplitReductionFn, useAlloc);
438  }
439 
440 private:
441  ControlSplitReductionFn controlSplitReductionFn;
442  bool useAlloc;
443 };
444 
445 } // namespace
446 
448  RewritePatternSet &patterns,
449  const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
450  patterns.add<LinalgSplitReduction>(patterns.getContext(),
451  controlSplitReductionFn, useAlloc);
452 }
static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand, unsigned reductionDimPos, int64_t reductionRatio)
Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...) TODO: Additional pattern to rewrite f(i,...
static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand, unsigned reductionDimPos, int64_t size)
A multi-dimensional affine map. AffineMaps are immutable, like Types, and they are uniqued.
Definition: AffineMap.h:47
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:399
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:318
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ...
Definition: AffineMap.h:260
AffineMap insertResult(AffineExpr expr, unsigned pos) const
Returns a new AffineMap with the same number of dims and symbols and an extra result inserted at pos.
Definition: AffineMap.h:308
AffineMap dropResult(int64_t pos) const
Returns a new AffineMap with the same number of dims and symbols and one less result at pos,...
Definition: AffineMap.h:286
unsigned getNumDims() const
Definition: AffineMap.cpp:378
unsigned getNumResults() const
Definition: AffineMap.cpp:386
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
Definition: AffineMap.cpp:499
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:540
Attributes are known-constant values of operations.
Definition: Attributes.h:25
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:394
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:371
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:553
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents an operand of an operation.
Definition: Value.h:263
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setOperand(unsigned idx, Value value)
Definition: Operation.h:346
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:717
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
result_type_range getResultTypes()
Definition: Operation.h:423
result_range getResults()
Definition: Operation.h:410
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:249
Builder & insertDim(int64_t val, unsigned pos)
Insert a val into shape @pos.
Definition: BuiltinTypes.h:283
iterator begin()
Definition: Region.h:55
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
Definition: ArithOps.cpp:2489
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void populateSplitReductionPattern(RewritePatternSet &patterns, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Patterns to apply splitReduction below.
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
Definition: Transforms.h:443
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
SmallVector< Value > createDynamicDimValues(OpBuilder &b, Location loc, Value rankedTensor)
Definition: Utils.cpp:43
Include the generated interface declarations.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:599
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:373
Split Reduction options.
Definition: Transforms.h:428
Apply transformation to split the single linalg op reduction into a parallel and reduction dimension.
Definition: Transforms.h:1011