//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns and a pass to remove the use of unit-extent
// dimensions for specifying broadcasting, in favor of a more canonical
// representation of the computation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "linalg-drop-unit-dims"

using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Pattern to move init operands to ins when all the loops are parallel and
/// the block argument corresponding to the init is used in the region. This is
/// a fix-up for when unit reduction dimensions are all folded away, at which
/// point the op becomes an elementwise generic op. E.g., it converts
///
///  %0 = tensor.empty() : tensor<1x1xf32>
///  %1 = linalg.fill
///    ins(%cst : f32)
///    outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
///  %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
///                                        affine_map<(d0) -> (0, d0)>],
///                       iterator_types = ["parallel"]}
///    ins(%arg0 : tensor<1x?x1x1xf32>)
///    outs(%1 : tensor<1x1xf32>) {
///  ^bb0(%in: f32, %out: f32):
///    %3 = arith.addf %in, %out : f32
///    linalg.yield %3 : f32
///  } -> tensor<1x1xf32>
///
/// into
///
///  %0 = tensor.empty() : tensor<1x1xf32>
///  %1 = linalg.fill
///    ins(%cst : f32)
///    outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
///  %2 = tensor.empty() : tensor<1x1xf32>
///  %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
///                                        affine_map<(d0) -> (0, d0)>,
///                                        affine_map<(d0) -> (0, d0)>],
///                       iterator_types = ["parallel"]}
///    ins(%arg0, %1 : tensor<1x?x1x1xf32>, tensor<1x1xf32>)
///    outs(%2 : tensor<1x1xf32>) {
///  ^bb0(%in: f32, %in_0: f32, %out: f32):
///    %4 = arith.addf %in, %in_0 : f32
///    linalg.yield %4 : f32
///  } -> tensor<1x1xf32>
struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasPureTensorSemantics())
      return failure();
    if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
      return failure();

    auto outputOperands = genericOp.getDpsInitsMutable();
    SetVector<OpOperand *> candidates;
    for (OpOperand &op : outputOperands) {
      if (genericOp.getMatchingBlockArgument(&op).use_empty())
        continue;
      candidates.insert(&op);
    }

    if (candidates.empty())
      return failure();

    // Compute the modified indexing maps.
    int64_t origNumInput = genericOp.getNumDpsInputs();
    SmallVector<Value> newInputOperands = genericOp.getDpsInputs();
    SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
    SmallVector<AffineMap> newIndexingMaps;
    newIndexingMaps.append(indexingMaps.begin(),
                           std::next(indexingMaps.begin(), origNumInput));
    for (OpOperand *op : candidates) {
      newInputOperands.push_back(op->get());
      newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op));
    }
    newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput),
                           indexingMaps.end());

    Location loc = genericOp.getLoc();
    SmallVector<Value> newOutputOperands =
        llvm::to_vector(genericOp.getDpsInits());
    for (OpOperand *op : candidates) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointAfterValue(op->get());
      auto elemType = cast<ShapedType>(op->get().getType()).getElementType();
      auto empty = tensor::EmptyOp::create(
          rewriter, loc, tensor::getMixedSizes(rewriter, loc, op->get()),
          elemType);

      unsigned start = genericOp.getDpsInits().getBeginOperandIndex();
      newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
    }

    auto newOp = GenericOp::create(
        rewriter, loc, genericOp.getResultTypes(), newInputOperands,
        newOutputOperands, newIndexingMaps, genericOp.getIteratorTypesArray(),
        /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));

    OpBuilder::InsertionGuard guard(rewriter);
    Region &region = newOp.getRegion();
    Block *block = rewriter.createBlock(&region);
    IRMapping mapper;
    for (auto bbarg : genericOp.getRegionInputArgs())
      mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));

    for (OpOperand *op : candidates) {
      BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
      mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
    }

    for (OpOperand &op : outputOperands) {
      BlockArgument bbarg = genericOp.getMatchingBlockArgument(&op);
      if (candidates.count(&op))
        block->addArgument(bbarg.getType(), loc);
      else
        mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
    }

    for (auto &op : genericOp.getBody()->getOperations()) {
      rewriter.clone(op, mapper);
    }
    rewriter.replaceOp(genericOp, newOp.getResults());

    return success();
  }
};
} // namespace

//===---------------------------------------------------------------------===//
// Drop loops that are unit-extents within Linalg operations.
//===---------------------------------------------------------------------===//

/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
/// broadcasting. For example,
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (0, d1)>,
///   affine_map<(d0, d1) -> (d0, 0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///     tensor<5x5xf32> {
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
///       tensor<5xf32> into tensor<1x5xf32>
///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
///       tensor<5xf32> into tensor<5x1xf32>
///   %2 = linalg.generic #trait %0, %1 {
///     ^bb0(%arg2: f32, %arg3: f32):
///       %3 = arith.addf %arg2, %arg3 : f32
///       linalg.yield %3 : f32
///   } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
///   return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (d1)>,
///   affine_map<(d0, d1) -> (d0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///     tensor<5x5xf32> {
///   %0 = linalg.generic #trait %arg0, %arg1 {
///     ^bb0(%arg2: f32, %arg3: f32):
///       %3 = arith.addf %arg2, %arg3 : f32
///       linalg.yield %3 : f32
///   } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
///   return %0 : tensor<5x5xf32>
/// }
/// ```

/// Update the index accesses of linalg operations having index semantics.
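///
/// For example (illustrative sketch): if loop dimension 1 of a rank-3
/// iteration space is a dropped unit dimension, then inside the body
///   - `linalg.index 1` is replaced by a constant `0` of index type, since a
///     unit-trip loop can only produce index 0, and
///   - `linalg.index 2` is replaced by `linalg.index 1`, since one loop
///     before it was removed.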
static void
replaceUnitDimIndexOps(GenericOp genericOp,
                       const llvm::SmallDenseSet<unsigned> &unitDims,
                       RewriterBase &rewriter) {
  for (IndexOp indexOp :
       llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(indexOp);
    if (unitDims.count(indexOp.getDim()) != 0) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
    } else {
      // Update the dimension of the index operation if needed.
      unsigned droppedDims = llvm::count_if(
          unitDims, [&](unsigned dim) { return dim < indexOp.getDim(); });
      if (droppedDims != 0)
        rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
                                             indexOp.getDim() - droppedDims);
    }
  }
}

FailureOr<Value>
ControlDropUnitDims::expandValue(RewriterBase &rewriter, Location loc,
                                 Value result, Value origDest,
                                 ArrayRef<ReassociationIndices> reassociation,
                                 const ControlDropUnitDims &control) {
  // There are no results for memref outputs.
  auto origResultType = cast<RankedTensorType>(origDest.getType());
  if (origResultType.getEncoding() != nullptr) {
    // Do not expand tensors with encoding.
    return failure();
  }
  if (control.rankReductionStrategy ==
      ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
    unsigned rank = origResultType.getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, loc, origDest);
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    return rewriter.createOrFold<tensor::InsertSliceOp>(
        loc, result, origDest, offsets, sizes, strides);
  }

  assert(control.rankReductionStrategy ==
             ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
         "unknown rank reduction strategy");
  return tensor::ExpandShapeOp::create(rewriter, loc, origResultType, result,
                                       reassociation)
      .getResult();
}

FailureOr<Value>
ControlDropUnitDims::collapseValue(RewriterBase &rewriter, Location loc,
                                   Value operand, ArrayRef<int64_t> targetShape,
                                   ArrayRef<ReassociationIndices> reassociation,
                                   const ControlDropUnitDims &control) {
  if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
    if (!memrefType.getLayout().isIdentity()) {
      // Do not collapse memrefs with a non-identity layout.
      return failure();
    }
    if (control.rankReductionStrategy ==
        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
      FailureOr<Value> rankReducingExtract =
          memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
                                                targetShape);
      assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
      return *rankReducingExtract;
    }

    assert(
        control.rankReductionStrategy ==
            ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
        "unknown rank reduction strategy");
    MemRefLayoutAttrInterface layout;
    auto targetType = MemRefType::get(targetShape, memrefType.getElementType(),
                                      layout, memrefType.getMemorySpace());
    return memref::CollapseShapeOp::create(rewriter, loc, targetType, operand,
                                           reassociation)
        .getResult();
  }
  if (auto tensorType = dyn_cast<RankedTensorType>(operand.getType())) {
    if (tensorType.getEncoding() != nullptr) {
      // Do not collapse tensors with an encoding.
      return failure();
    }
    if (control.rankReductionStrategy ==
        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
      FailureOr<Value> rankReducingExtract =
          tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
                                                     targetShape);
      assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
      return *rankReducingExtract;
    }

    assert(
        control.rankReductionStrategy ==
            ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
        "unknown rank reduction strategy");
    auto targetType =
        RankedTensorType::get(targetShape, tensorType.getElementType());
    return tensor::CollapseShapeOp::create(rewriter, loc, targetType, operand,
                                           reassociation)
        .getResult();
  }
  llvm_unreachable("unsupported operand type");
}

/// Metadata for an operand of an operation whose unit dims are being dropped:
/// the new indexing map to use, the shape of the operand in the replacement
/// op, and the `reassociation` to use to go from the original operand shape
/// to the modified operand shape.
struct UnitExtentReplacementInfo {
  AffineMap indexMap;
  SmallVector<ReassociationIndices> reassociation;
  SmallVector<int64_t> targetShape;
};

/// Compute the modified metadata for an operand of an operation whose unit
/// dims are being dropped.
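///
/// For example (illustrative sketch): for an operand of type
/// tensor<1x?x1x1xf32> accessed via affine_map<(d0) -> (0, d0, 0, 0)>, all
/// three unit dimensions fold away and the computed metadata is roughly:
///   - indexMap:      affine_map<(d0) -> (d0)>
///   - targetShape:   [?]
///   - reassociation: [[0, 1, 2, 3]]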
static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
    MLIRContext *context, IndexingMapOpInterface op, OpOperand *opOperand,
    llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
    ArrayRef<AffineExpr> dimReplacements) {
  UnitExtentReplacementInfo info;
  ReassociationIndices reassociationGroup;
  SmallVector<AffineExpr> newIndexExprs;
  AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
  SmallVector<int64_t> operandShape = op.getStaticOperandShape(opOperand);
  ArrayRef<AffineExpr> exprs = indexingMap.getResults();

  auto isUnitDim = [&](unsigned dim) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[dim])) {
      unsigned oldPosition = dimExpr.getPosition();
      return !oldDimsToNewDimsMap.count(oldPosition) &&
             (operandShape[dim] == 1);
    }
    // Handle the other case where the shape is 1, and is accessed using a
    // constant 0.
    if (operandShape[dim] == 1) {
      auto constAffineExpr = dyn_cast<AffineConstantExpr>(exprs[dim]);
      return constAffineExpr && constAffineExpr.getValue() == 0;
    }
    return false;
  };

  unsigned dim = 0;
  while (dim < operandShape.size() && isUnitDim(dim))
    reassociationGroup.push_back(dim++);
  while (dim < operandShape.size()) {
    assert(!isUnitDim(dim) && "expected non unit-extent");
    reassociationGroup.push_back(dim);
    AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
    newIndexExprs.push_back(newExpr);
    info.targetShape.push_back(operandShape[dim]);
    ++dim;
    // Fold all following dimensions that are unit-extent.
    while (dim < operandShape.size() && isUnitDim(dim)) {
      reassociationGroup.push_back(dim++);
    }
    info.reassociation.push_back(reassociationGroup);
    reassociationGroup.clear();
  }
  info.indexMap =
      AffineMap::get(oldDimsToNewDimsMap.size(), indexingMap.getNumSymbols(),
                     newIndexExprs, context);
  return info;
}

FailureOr<DropUnitDimsResult>
linalg::dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
                     const DroppedUnitDimsBuilder &droppedUnitDimsBuilder,
                     const ControlDropUnitDims &options) {
  auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op.getOperation());
  if (!dpsOp) {
    return rewriter.notifyMatchFailure(
        op, "op should implement DestinationStyleOpInterface");
  }

  SmallVector<AffineMap> indexingMaps = op.getIndexingMapsArray();
  if (indexingMaps.empty())
    return failure();

  // 1. Check if any of the iteration dimensions are unit-trip count. They will
  //    end up being unit-trip count if they are used to index into a unit-dim
  //    tensor/memref.
  AffineMap invertedMap = inversePermutation(
      concatAffineMaps(indexingMaps, rewriter.getContext()));
  if (!invertedMap) {
    return rewriter.notifyMatchFailure(op,
                                       "invalid indexing maps for operation");
  }

  SmallVector<int64_t> allShapesSizes;
  for (OpOperand &opOperand : op->getOpOperands())
    llvm::append_range(allShapesSizes, op.getStaticOperandShape(&opOperand));

  // 1a. Get the allowed list of dimensions to drop from the `options`.
  SmallVector<unsigned> allowedUnitDims = options.controlFn(op);
  if (allowedUnitDims.empty()) {
    return rewriter.notifyMatchFailure(
        op, "control function returns no allowed unit dims to prune");
  }
  llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
                                               allowedUnitDims.end());
  llvm::SmallDenseSet<unsigned> unitDims;
  for (const auto &expr : enumerate(invertedMap.getResults())) {
    if (AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(expr.value())) {
      if (allShapesSizes[dimExpr.getPosition()] == 1 &&
          unitDimsFilter.count(expr.index()))
        unitDims.insert(expr.index());
    }
  }

  // 2. Compute the new loops of the modified op by dropping the one-trip
  //    count loops.
  llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
  SmallVector<AffineExpr> dimReplacements;
  unsigned newDims = 0;
  for (auto index : llvm::seq<int64_t>(op.getStaticLoopRanges().size())) {
    if (unitDims.count(index)) {
      dimReplacements.push_back(
          getAffineConstantExpr(0, rewriter.getContext()));
    } else {
      oldDimToNewDimMap[index] = newDims;
      dimReplacements.push_back(
          getAffineDimExpr(newDims, rewriter.getContext()));
      newDims++;
    }
  }

  // 3. For each of the operands, find the
  //    - modified affine map to use,
  //    - shape of the operand after the unit-dims are dropped, and
  //    - the reassociation indices used to convert from the original
  //      operand type to the modified operand type (needed only when using
  //      reshapes for the rank reduction strategy).
  // Note that the indexing maps might need changing even if there are no
  // unit dimensions that are dropped, to handle cases where `0` is used to
  // access a unit-extent tensor. Consider moving this out of this specific
  // transformation as a stand-alone transformation. Kept here right now due
  // to legacy.
  SmallVector<AffineMap> newIndexingMaps;
  SmallVector<SmallVector<ReassociationIndices>> reassociations;
  SmallVector<SmallVector<int64_t>> targetShapes;
  SmallVector<bool> collapsed;
  for (OpOperand &opOperand : op->getOpOperands()) {
    auto indexingMap = op.getMatchingIndexingMap(&opOperand);
    auto replacementInfo = dropUnitExtentFromOperandMetadata(
        rewriter.getContext(), op, &opOperand, oldDimToNewDimMap,
        dimReplacements);
    reassociations.push_back(replacementInfo.reassociation);
    newIndexingMaps.push_back(replacementInfo.indexMap);
    targetShapes.push_back(replacementInfo.targetShape);
    collapsed.push_back(replacementInfo.indexMap.getNumResults() !=
                        indexingMap.getNumResults());
  }

  // Abort if the indexing maps of the result operation are not invertible
  // (i.e. not legal) or if no dimension was reduced.
  if (newIndexingMaps == indexingMaps ||
      !inversePermutation(
          concatAffineMaps(newIndexingMaps, rewriter.getContext())))
    return failure();

  Location loc = op.getLoc();
  // 4. For each of the operands, collapse the operand to convert
  //    from the original shape to the shape in the modified operation if
  //    needed, either through use of reshapes or rank-reducing slices as
  //    specified in `options`.
  // Abort if one of the operands cannot be collapsed.
  SmallVector<Value> newOperands;
  for (OpOperand &opOperand : op->getOpOperands()) {
    int64_t idx = opOperand.getOperandNumber();
    if (!collapsed[idx]) {
      newOperands.push_back(opOperand.get());
      continue;
    }
    FailureOr<Value> collapsedOperand =
        options.collapseFn(rewriter, loc, opOperand.get(), targetShapes[idx],
                           reassociations[idx], options);
    if (failed(collapsedOperand)) {
      // Abort if the operand could not be collapsed.
      return failure();
    }
    newOperands.push_back(collapsedOperand.value());
  }

  // 5. Create the replacement operation from the collapsed operands and the
  //    modified indexing maps.
  IndexingMapOpInterface replacementOp = droppedUnitDimsBuilder(
      loc, rewriter, op, newOperands, newIndexingMaps, unitDims);

  // 6. If any result type changes, insert a reshape/slice to convert from the
  //    original type to the new type.
  // Abort the transformation if the result cannot be expanded back to its
  // original shape.
  SmallVector<Value> resultReplacements;
  for (auto [index, result] : llvm::enumerate(replacementOp->getResults())) {
    unsigned opOperandIndex = index + dpsOp.getNumDpsInputs();
    Value origDest = dpsOp.getDpsInitOperand(index)->get();
    if (!collapsed[opOperandIndex]) {
      resultReplacements.push_back(result);
      continue;
    }
    FailureOr<Value> expanded =
        options.expandFn(rewriter, loc, result, origDest,
                         reassociations[opOperandIndex], options);
    if (failed(expanded)) {
      // Abort if expansion is not successful.
      return failure();
    }
    resultReplacements.push_back(expanded.value());
  }

  return DropUnitDimsResult{replacementOp, resultReplacements};
}

FailureOr<DropUnitDimsResult>
linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
                     const ControlDropUnitDims &options) {
  DroppedUnitDimsBuilder build =
      [](Location loc, OpBuilder &b, IndexingMapOpInterface op,
         ArrayRef<Value> newOperands, ArrayRef<AffineMap> newIndexingMaps,
         const llvm::SmallDenseSet<unsigned> &droppedDims)
      -> IndexingMapOpInterface {
    auto genericOp = cast<GenericOp>(op);
    // Compute the iterator types of the modified op by dropping the one-trip
    // count loops.
    SmallVector<utils::IteratorType> newIteratorTypes;
    for (auto [index, attr] :
         llvm::enumerate(genericOp.getIteratorTypesArray())) {
      if (!droppedDims.count(index))
        newIteratorTypes.push_back(attr);
    }

    // Create the `linalg.generic` operation with the new operands,
    // indexing maps, iterator types and result types.
    ArrayRef<Value> newInputs =
        ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
    ArrayRef<Value> newOutputs =
        ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
    SmallVector<Type> resultTypes;
    resultTypes.reserve(genericOp.getNumResults());
    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
      resultTypes.push_back(newOutputs[i].getType());
    GenericOp replacementOp =
        GenericOp::create(b, loc, resultTypes, newInputs, newOutputs,
                          newIndexingMaps, newIteratorTypes);
    b.cloneRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
                        replacementOp.getRegion().begin());
    // 5a. Replace `linalg.index` operations that refer to the dropped unit
    //     dimensions.
    IRRewriter rewriter(b);
    replaceUnitDimIndexOps(replacementOp, droppedDims, rewriter);

    return replacementOp;
  };

  return dropUnitDims(rewriter, genericOp, build, options);
}

namespace {
struct DropUnitDims : public OpRewritePattern<GenericOp> {
  DropUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
               PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit), options(std::move(options)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    FailureOr<DropUnitDimsResult> result =
        dropUnitDims(rewriter, genericOp, options);
    if (failed(result)) {
      return failure();
    }
    rewriter.replaceOp(genericOp, result->replacements);
    return success();
  }

private:
  ControlDropUnitDims options;
};
} // namespace

//===---------------------------------------------------------------------===//
// Drop dimensions that are unit-extents within tensor operations.
//===---------------------------------------------------------------------===//

namespace {
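/// Drop unit-extent dimensions of a `tensor.pad` whose low/high padding is
/// statically zero. As an illustrative sketch, this rewrites something like
///
///   %0 = tensor.pad %arg0 low[0, 1, 0] high[0, 1, 0] { ... }
///       : tensor<1x4x1xf32> to tensor<1x6x1xf32>
///
/// into a pad of the collapsed source followed by an expansion back to the
/// original rank:
///
///   %c = tensor.collapse_shape %arg0 [[0, 1, 2]]
///       : tensor<1x4x1xf32> into tensor<4xf32>
///   %p = tensor.pad %c low[1] high[1] { ... }
///       : tensor<4xf32> to tensor<6xf32>
///   %e = tensor.expand_shape %p [[0, 1, 2]]
///       : tensor<6xf32> into tensor<1x6x1xf32>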
struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
  DropPadUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
                  PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit), options(std::move(options)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    // 1a. Get the allowed list of dimensions to drop from the `options`.
    SmallVector<unsigned> allowedUnitDims = options.controlFn(padOp);
    if (allowedUnitDims.empty()) {
      return rewriter.notifyMatchFailure(
          padOp, "control function returns no allowed unit dims to prune");
    }

    if (padOp.getSourceType().getEncoding()) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot collapse dims of tensor with encoding");
    }

    // Fail for non-constant padding values. The body of the pad could
    // depend on the padding indices and/or properties of the padded
    // tensor so for now we fail.
    // TODO: Support non-constant padding values.
    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal) {
      return rewriter.notifyMatchFailure(
          padOp, "unimplemented: non-constant padding value");
    }

    ArrayRef<int64_t> sourceShape = padOp.getSourceType().getShape();
    ArrayRef<int64_t> resultShape = padOp.getResultType().getShape();
    int64_t padRank = sourceShape.size();

    auto isStaticZero = [](OpFoldResult f) {
      return getConstantIntValue(f) == 0;
    };

    llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
                                                 allowedUnitDims.end());
    llvm::SmallDenseSet<unsigned> unitDims;
    SmallVector<int64_t> newShape;
    SmallVector<int64_t> newResultShape;
    SmallVector<OpFoldResult> newLowPad;
    SmallVector<OpFoldResult> newHighPad;
    for (const auto [dim, size, outSize, low, high] : zip_equal(
             llvm::seq(static_cast<int64_t>(0), padRank), sourceShape,
             resultShape, padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
      if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) &&
          isStaticZero(high)) {
        unitDims.insert(dim);
      } else {
        newShape.push_back(size);
        newResultShape.push_back(outSize);
        newLowPad.push_back(low);
        newHighPad.push_back(high);
      }
    }

    if (unitDims.empty()) {
      return rewriter.notifyMatchFailure(padOp, "no unit dims to collapse");
    }

    ReassociationIndices reassociationGroup;
    SmallVector<ReassociationIndices> reassociationMap;
    int64_t dim = 0;
    while (dim < padRank && unitDims.contains(dim))
      reassociationGroup.push_back(dim++);
    while (dim < padRank) {
      assert(!unitDims.contains(dim) && "expected non unit-extent");
      reassociationGroup.push_back(dim);
      dim++;
      // Fold all following dimensions that are unit-extent.
      while (dim < padRank && unitDims.contains(dim))
        reassociationGroup.push_back(dim++);
      reassociationMap.push_back(reassociationGroup);
      reassociationGroup.clear();
    }

    FailureOr<Value> collapsedSource =
        options.collapseFn(rewriter, padOp.getLoc(), padOp.getSource(),
                           newShape, reassociationMap, options);
    if (failed(collapsedSource)) {
      return rewriter.notifyMatchFailure(padOp, "failed to collapse source");
    }

    auto newResultType = RankedTensorType::get(
        newResultShape, padOp.getResultType().getElementType());
    auto newPadOp = tensor::PadOp::create(
        rewriter, padOp.getLoc(), /*result=*/newResultType,
        collapsedSource.value(), newLowPad, newHighPad, paddingVal,
        padOp.getNofold());

    Value dest = padOp.getResult();
    if (options.rankReductionStrategy ==
        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
      SmallVector<OpFoldResult> expandedSizes;
      int64_t numUnitDims = 0;
      for (auto dim : llvm::seq(static_cast<int64_t>(0), padRank)) {
        if (unitDims.contains(dim)) {
          expandedSizes.push_back(rewriter.getIndexAttr(1));
          numUnitDims++;
          continue;
        }
        expandedSizes.push_back(tensor::getMixedSize(
            rewriter, padOp.getLoc(), newPadOp, dim - numUnitDims));
      }
      dest = tensor::EmptyOp::create(rewriter, padOp.getLoc(), expandedSizes,
                                     padOp.getResultType().getElementType());
    }

    FailureOr<Value> expandedValue =
        options.expandFn(rewriter, padOp.getLoc(), newPadOp.getResult(), dest,
                         reassociationMap, options);
    if (failed(expandedValue)) {
      return rewriter.notifyMatchFailure(padOp, "failed to expand result");
    }
    rewriter.replaceOp(padOp, expandedValue.value());
    return success();
  }

private:
  ControlDropUnitDims options;
};
} // namespace

namespace {
/// Convert `extract_slice` operations to rank-reduced versions.
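///
/// As an illustrative sketch, a slice that yields unit dims such as
///
///   %0 = tensor.extract_slice %t[0, 0, 0] [1, 4, 1] [1, 1, 1]
///       : tensor<8x4x2xf32> to tensor<1x4x1xf32>
///
/// becomes a rank-reduced slice followed by an expansion:
///
///   %s = tensor.extract_slice %t[0, 0, 0] [1, 4, 1] [1, 1, 1]
///       : tensor<8x4x2xf32> to tensor<4xf32>
///   %0 = tensor.expand_shape %s [[0, 1, 2]]
///       : tensor<4xf32> into tensor<1x4x1xf32>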
struct RankReducedExtractSliceOp
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = sliceOp.getType();
    SmallVector<OpFoldResult> targetShape;
    for (auto size : resultType.getShape())
      targetShape.push_back(rewriter.getIndexAttr(size));
    auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(resultType.getRank()))
      return failure();

    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
    auto rankReducedType = cast<RankedTensorType>(
        tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
            reassociation->size(), sliceOp.getSourceType(), sizes));

    Location loc = sliceOp.getLoc();
    Value newSlice = tensor::ExtractSliceOp::create(
        rewriter, loc, rankReducedType, sliceOp.getSource(), offsets, sizes,
        strides);
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        sliceOp, resultType, newSlice, *reassociation);
    return success();
  }
};

/// Convert `insert_slice` operations to rank-reduced versions.
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
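///
/// As an illustrative sketch, the source of
///
///   %0 = tensor.insert_slice %src into %dst[0, 2, 0] [1, 4, 1] [1, 1, 1]
///       : tensor<1x4x1xf32> into tensor<1x8x1xf32>
///
/// is first collapsed, and the insertion is re-emitted with the rank-reduced
/// source:
///
///   %c = tensor.collapse_shape %src [[0, 1, 2]]
///       : tensor<1x4x1xf32> into tensor<4xf32>
///   %0 = tensor.insert_slice %c into %dst[0, 2, 0] [1, 4, 1] [1, 1, 1]
///       : tensor<4xf32> into tensor<1x8x1xf32>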
template <typename InsertOpTy>
struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType sourceType = insertSliceOp.getSourceType();
    SmallVector<OpFoldResult> targetShape;
    for (auto size : sourceType.getShape())
      targetShape.push_back(rewriter.getIndexAttr(size));
    auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
      return failure();

    Location loc = insertSliceOp.getLoc();
    tensor::CollapseShapeOp reshapedSource;
    {
      OpBuilder::InsertionGuard g(rewriter);
      // The only difference between InsertSliceOp and ParallelInsertSliceOp
      // is that the insertion point is just before the ParallelCombiningOp in
      // the parallel case.
      if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      reshapedSource = tensor::CollapseShapeOp::create(
          rewriter, loc, insertSliceOp.getSource(), *reassociation);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, reshapedSource, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
} // namespace

/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns, ControlDropUnitDims &options) {
  auto *context = patterns.getContext();
  patterns.add<DropUnitDims>(context, options);
  patterns.add<DropPadUnitDims>(context, options);
}
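
// A minimal usage sketch (not part of this file): clients populate the
// patterns and run them with a driver of their choice. The surrounding
// function name and the `op` variable below are hypothetical.
//
//   void runDropUnitDims(Operation *op) {
//     RewritePatternSet patterns(op->getContext());
//     linalg::ControlDropUnitDims options;
//     linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
//     walkAndApplyPatterns(op, std::move(patterns));
//   }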

void mlir::linalg::populateFoldUnitExtentDimsCanonicalizationPatterns(
    RewritePatternSet &patterns, ControlDropUnitDims &options) {
  auto *context = patterns.getContext();
  bool reassociativeReshape =
      options.rankReductionStrategy ==
      ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape;
  if (reassociativeReshape) {
    patterns.add<RankReducedExtractSliceOp,
                 RankReducedInsertSliceOp<tensor::InsertSliceOp>,
                 RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
        context);
    tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
    tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
  }
  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
  tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
  tensor::populateFoldTensorEmptyPatterns(patterns);
  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
  memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}

void mlir::linalg::populateMoveInitOperandsToInputPattern(
    RewritePatternSet &patterns) {
  patterns.add<MoveInitOperandsToInput>(patterns.getContext());
}

namespace {
/// Pass that removes unit-extent dims within generic ops.
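///
/// For reference, this pass is registered as `linalg-fold-unit-extent-dims`.
/// A typical invocation (the `use-rank-reducing-slices` option name is an
/// assumption based on the `useRankReducingSlices` member below) is:
///
///   mlir-opt --linalg-fold-unit-extent-dims=use-rank-reducing-slices=true in.mlir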
struct LinalgFoldUnitExtentDimsPass
    : public impl::LinalgFoldUnitExtentDimsPassBase<
          LinalgFoldUnitExtentDimsPass> {
  using impl::LinalgFoldUnitExtentDimsPassBase<
      LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase;
  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();
    ControlDropUnitDims options;
    if (useRankReducingSlices) {
      options.rankReductionStrategy = linalg::ControlDropUnitDims::
          RankReductionStrategy::ExtractInsertSlice;
    }

    // Apply the fold unit extent dims patterns with the walk-based driver.
    {
      RewritePatternSet patterns(context);
      populateFoldUnitExtentDimsPatterns(patterns, options);
      walkAndApplyPatterns(op, std::move(patterns));
    }

    // Apply canonicalization patterns with the greedy driver.
    {
      RewritePatternSet patterns(context);
      populateFoldUnitExtentDimsCanonicalizationPatterns(patterns, options);
      (void)applyPatternsGreedily(op, std::move(patterns));
    }
  }
};

} // namespace

namespace {

/// Returns reassociation indices for collapsing/expanding a
/// tensor of rank `rank` at position `pos`.
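///
/// For example (illustrative), `rank = 4, pos = 1` yields `[[0], [1, 2], [3]]`:
/// the unit dimension at position 1 is folded into the group of its neighboring
/// dimension. When `pos` is the last dimension, it folds into the preceding
/// group instead, e.g. `rank = 3, pos = 2` yields `[[0], [1, 2]]`.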
static SmallVector<ReassociationIndices>
getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
  SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
  bool lastDim = pos == rank - 1;
  if (rank > 2) {
    for (int64_t i = 0; i < rank - 1; i++) {
      if (i == pos || (lastDim && i == pos - 1))
        reassociation[i] = ReassociationIndices{i, i + 1};
      else if (i < pos)
        reassociation[i] = ReassociationIndices{i};
      else
        reassociation[i] = ReassociationIndices{i + 1};
    }
  }
  return reassociation;
}

/// Returns a collapsed `val` where the collapsing occurs at dim `pos`.
/// If `pos < 0`, then don't collapse.
static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
                                    int64_t pos) {
  if (pos < 0)
    return val;
  auto valType = cast<ShapedType>(val.getType());
  SmallVector<int64_t> collapsedShape(valType.getShape());
  collapsedShape.erase(collapsedShape.begin() + pos);
  ControlDropUnitDims control{};
  FailureOr<Value> collapsed = control.collapseFn(
      rewriter, val.getLoc(), val, collapsedShape,
      getReassociationForReshapeAtDim(valType.getRank(), pos), control);
  assert(llvm::succeeded(collapsed) && "Collapsing the value failed");
  return collapsed.value();
}

/// Base class for all rank reduction patterns for contraction ops
/// with unit dimensions. All patterns should convert one named op
/// to another named op. Intended to reduce only one iteration space dim
/// at a time.
/// Reducing multiple dims will happen with recursive application of
/// pattern rewrites.
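///
/// As an illustrative sketch, a batched matmul with a unit batch dimension
///
///   %0 = linalg.batch_matmul
///       ins(%a, %b : tensor<1x4x8xf32>, tensor<1x8x16xf32>)
///       outs(%c : tensor<1x4x16xf32>) -> tensor<1x4x16xf32>
///
/// is rewritten by collapsing the unit dim of every operand, emitting the
/// unbatched op, and expanding the result back:
///
///   %a0 = tensor.collapse_shape %a [[0, 1], [2]]
///       : tensor<1x4x8xf32> into tensor<4x8xf32>
///   %b0 = tensor.collapse_shape %b [[0, 1], [2]]
///       : tensor<1x8x16xf32> into tensor<8x16xf32>
///   %c0 = tensor.collapse_shape %c [[0, 1], [2]]
///       : tensor<1x4x16xf32> into tensor<4x16xf32>
///   %m = linalg.matmul ins(%a0, %b0 : tensor<4x8xf32>, tensor<8x16xf32>)
///       outs(%c0 : tensor<4x16xf32>) -> tensor<4x16xf32>
///   %0 = tensor.expand_shape %m [[0, 1], [2]]
///       : tensor<4x16xf32> into tensor<1x4x16xf32>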
template <typename FromOpTy, typename ToOpTy>
struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
  using OpRewritePattern<FromOpTy>::OpRewritePattern;

  /// Collapse all collapsible operands.
  SmallVector<Value>
  collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
                   ArrayRef<int64_t> operandCollapseDims) const {
    assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
           "expected 3 operands and dims");
    return llvm::map_to_vector(
        llvm::zip(operands, operandCollapseDims), [&](auto pair) {
          return collapseSingletonDimAt(rewriter, std::get<0>(pair),
                                        std::get<1>(pair));
        });
  }

  /// Expand result tensor.
  Value expandResult(PatternRewriter &rewriter, Value result,
                     RankedTensorType expandedType, int64_t dim) const {
    return tensor::ExpandShapeOp::create(
        rewriter, result.getLoc(), expandedType, result,
        getReassociationForReshapeAtDim(expandedType.getRank(), dim));
  }

  LogicalResult matchAndRewrite(FromOpTy contractionOp,
                                PatternRewriter &rewriter) const override {
    if (contractionOp.hasUserDefinedMaps()) {
      return rewriter.notifyMatchFailure(
          contractionOp, "ops with user-defined maps are not supported");
    }

    auto loc = contractionOp.getLoc();
    auto inputs = contractionOp.getDpsInputs();
    auto inits = contractionOp.getDpsInits();
    if (inputs.size() != 2 || inits.size() != 1)
      return rewriter.notifyMatchFailure(contractionOp,
                                         "expected 2 inputs and 1 init");
    auto lhs = inputs[0];
    auto rhs = inputs[1];
    auto init = inits[0];
    SmallVector<Value> operands{lhs, rhs, init};

    SmallVector<int64_t> operandUnitDims;
    if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
      return rewriter.notifyMatchFailure(contractionOp,
                                         "no reducible dims found");

    SmallVector<Value> collapsedOperands =
        collapseOperands(rewriter, operands, operandUnitDims);
    Value collapsedLhs = collapsedOperands[0];
    Value collapsedRhs = collapsedOperands[1];
    Value collapsedInit = collapsedOperands[2];
    SmallVector<Type, 1> collapsedResultTy;
    if (isa<RankedTensorType>(collapsedInit.getType()))
      collapsedResultTy.push_back(collapsedInit.getType());
    auto collapsedOp = ToOpTy::create(rewriter, loc, collapsedResultTy,
                                      ValueRange{collapsedLhs, collapsedRhs},
                                      ValueRange{collapsedInit});
    for (auto attr : contractionOp->getAttrs()) {
      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName ||
          attr.getName() == "indexing_maps")
        continue;
      collapsedOp->setAttr(attr.getName(), attr.getValue());
    }

    auto results = contractionOp.getResults();
    assert(results.size() < 2 && "expected at most one result");
    if (results.empty()) {
      rewriter.replaceOp(contractionOp, collapsedOp);
    } else {
      rewriter.replaceOp(
          contractionOp,
          expandResult(rewriter, collapsedOp.getResultTensors()[0],
                       cast<RankedTensorType>(results[0].getType()),
                       operandUnitDims[2]));
    }

    return success();
  }

  /// Populate `operandUnitDims` with 3 indices indicating the unit dim
  /// for each operand that should be collapsed in this pattern. If an
  /// operand shouldn't be collapsed, the index should be negative.
  virtual LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const = 0;
};

/// Patterns for unbatching batched contraction ops.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

  /// Look for unit batch dims to collapse.
  LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
    FailureOr<ContractionDimensions> maybeContractionDims =
        inferContractionDims(op);
    if (failed(maybeContractionDims)) {
      LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
      return failure();
    }
    const ContractionDimensions &contractionDims =
        maybeContractionDims.value();

    if (contractionDims.batch.size() != 1)
      return failure();
    auto batchDim = contractionDims.batch[0];
    SmallVector<std::pair<Value, unsigned>, 3> bOperands;
    op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
    if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) {
          return cast<ShapedType>(std::get<0>(pair).getType())
                     .getShape()[std::get<1>(pair)] != 1;
        })) {
      LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
      return failure();
    }

    operandUnitDims = SmallVector<int64_t>{std::get<1>(bOperands[0]),
                                           std::get<1>(bOperands[1]),
                                           std::get<1>(bOperands[2])};
    return success();
  }
};

/// Patterns for reducing non-batch dimensions.
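///
/// As an illustrative sketch, `RankReduceMatmul<MatmulOp, VecmatOp>` rewrites
///
///   %0 = linalg.matmul ins(%a, %b : tensor<1x8xf32>, tensor<8x16xf32>)
///       outs(%c : tensor<1x16xf32>) -> tensor<1x16xf32>
///
/// into
///
///   %a0 = tensor.collapse_shape %a [[0, 1]]
///       : tensor<1x8xf32> into tensor<8xf32>
///   %c0 = tensor.collapse_shape %c [[0, 1]]
///       : tensor<1x16xf32> into tensor<16xf32>
///   %v = linalg.vecmat ins(%a0, %b : tensor<8xf32>, tensor<8x16xf32>)
///       outs(%c0 : tensor<16xf32>) -> tensor<16xf32>
///   %0 = tensor.expand_shape %v [[0, 1]]
///       : tensor<16xf32> into tensor<1x16xf32>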
template <typename FromOpTy, typename ToOpTy>
struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

  /// Helper for determining whether the lhs/init or rhs/init are reduced.
  static bool constexpr reduceLeft =
      (std::is_same_v<FromOpTy, BatchMatmulOp> &&
       std::is_same_v<ToOpTy, BatchVecmatOp>) ||
      (std::is_same_v<FromOpTy, MatmulOp> &&
       std::is_same_v<ToOpTy, VecmatOp>) ||
      (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);

  /// Look for non-batch spatial dims to collapse.
  LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
    FailureOr<ContractionDimensions> maybeContractionDims =
        inferContractionDims(op);
    if (failed(maybeContractionDims)) {
      LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
      return failure();
    }
    const ContractionDimensions &contractionDims =
        maybeContractionDims.value();

    if constexpr (reduceLeft) {
      auto m = contractionDims.m[0];
      SmallVector<std::pair<Value, unsigned>, 2> mOperands;
      op.mapIterationSpaceDimToAllOperandDims(m, mOperands);
      if (mOperands.size() != 2)
        return failure();
      if (llvm::all_of(mOperands, [](auto pair) {
            return cast<ShapedType>(std::get<0>(pair).getType())
                       .getShape()[std::get<1>(pair)] == 1;
          })) {
        operandUnitDims = SmallVector<int64_t>{std::get<1>(mOperands[0]), -1,
                                               std::get<1>(mOperands[1])};
        return success();
      }
    } else {
      auto n = contractionDims.n[0];
      SmallVector<std::pair<Value, unsigned>, 2> nOperands;
      op.mapIterationSpaceDimToAllOperandDims(n, nOperands);
      if (nOperands.size() != 2)
        return failure();
      if (llvm::all_of(nOperands, [](auto pair) {
            return cast<ShapedType>(std::get<0>(pair).getType())
                       .getShape()[std::get<1>(pair)] == 1;
          })) {
        operandUnitDims = SmallVector<int64_t>{-1, std::get<1>(nOperands[0]),
                                               std::get<1>(nOperands[1])};
        return success();
      }
    }
    LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
    return failure();
  }
};

} // namespace

void mlir::linalg::populateContractionOpRankReducingPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  // Unbatching patterns for unit batch size.
  patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
  patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
  patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);

  // Non-batch rank-1 reducing patterns.
  patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
  patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
  // Batch rank-1 reducing patterns.
  patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
  patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);

  // Non-batch rank-0 reducing patterns.
  patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
  patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
}
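
// A minimal usage sketch (not part of this file): because each pattern drops a
// single iteration-space dim, a rewrite driver applies them recursively, e.g.
// a `linalg.batch_matmul` with unit batch and unit M dimensions reduces first
// to `linalg.matmul` and then to `linalg.vecmat` in successive rewrites. The
// surrounding function name and the `op` variable below are hypothetical.
//
//   void reduceContractionRank(Operation *op) {
//     RewritePatternSet patterns(op->getContext());
//     linalg::populateContractionOpRankReducingPatterns(patterns);
//     (void)applyPatternsGreedily(op, std::move(patterns));
//   }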