//===- TosaReduceTransposes.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// ----------
// Motivation:
// ----------

// Some legalization pathways introduce redundant tosa.TRANSPOSE
// operations that result in avoidable data movement. For example,
// the PyTorch -> TOSA legalization contains many unnecessary transposes
// due to conversions between NCHW and NHWC.

// We wish to remove all the ones that we can, since in general
// it is possible to remove the overwhelming majority.

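// As an illustrative sketch (this example is not from the original file):
// an NCHW -> NHWC transpose (perms [0, 2, 3, 1]) whose result, after some
// elementwise work, feeds its inverse (perms [0, 3, 1, 2]) can be removed
// entirely, since the two permutations compose to the identity.
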
// -------------------
// High-Level Overview:
// -------------------

// The pass works through the transpose operators in the program. It begins at
// some transpose operator with an associated permutation tensor. It traverses
// upwards through the dependencies of this transpose and verifies that we
// encounter only operators with the TosaElementwiseOperator trait, terminating
// in constants, reshapes, or transposes.

// We then evaluate whether there are any additional restrictions (the
// transposes it terminates in must invert the one we began at, and the
// reshapes must be ones into which we can fold the transpose), and then we
// hoist the transpose through the intervening operators, folding it at the
// constants, reshapes, and transposes.

// Finally, we ensure that we do not need both the transposed form (the form
// that had the transpose hoisted through it) and the untransposed form (its
// prior form), by analyzing the usages of the dependent operators of a
// given transpose we are attempting to hoist and replace.

// If the usages would require both forms to remain live, then we do not
// replace the hoisted transpose, causing the new chain to be dead.
// Otherwise, we do, and the old chain (untransposed form) becomes dead. Only
// one chain will ever then be live, resulting in no duplication.

// We then perform a simple one-pass DCE, so no canonicalization is necessary.

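// As a minimal illustration (hypothetical IR, not from the original file):

// %0 = tosa.transpose %arg0 {perms = [1, 0]} : (tensor<2x3xf32>) -> tensor<3x2xf32>
// %1 = tosa.ceil %0 : (tensor<3x2xf32>) -> tensor<3x2xf32>
// %2 = tosa.transpose %1 {perms = [1, 0]} : (tensor<3x2xf32>) -> tensor<2x3xf32>

// Hoisting the transpose at %2 above the ceil folds it into the transpose at
// %0 (the two perms invert each other), so the chain reduces to:

// %2 = tosa.ceil %arg0 : (tensor<2x3xf32>) -> tensor<2x3xf32>
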
// -----------
// Future Work:
// -----------

// (1) Evaluate tradeoffs with permitting ConstOp to be duplicated across
// hoisted transposes with different permutation tensors.

// (2) Expand the class of foldable upstream ReshapeOp we permit beyond
// N -> 1x1x...x1xNx1x...x1x1.

// (3) Enhance the pass to permit folding arbitrary transpose pairs, beyond
// those that form the identity.

// (4) Add support for more instructions besides TosaElementwiseOperator as
// the intervening ones (for example, the reduce_* operators).

// (5) Support hoisting transposes up to an input parameter.

//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/Iterators.h"
#include "llvm/ADT/TypeSwitch.h"
#include <set>
#include <stack>

namespace mlir {
namespace tosa {
#define GEN_PASS_DEF_TOSAREDUCETRANSPOSES
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
} // namespace tosa
} // namespace mlir

using namespace mlir;
using namespace mlir::tosa;

//===----------------------------------------------------------------------===//
// TOSA Reduce Transposes Pass.
//===----------------------------------------------------------------------===//

namespace {

struct TosaReduceTransposes final
    : public tosa::impl::TosaReduceTransposesBase<TosaReduceTransposes> {
  void runOnOperation() override;

private:
  // This will collect all the data dependencies for the given Operation
  // up to and including ConstOp, ReshapeOp, and TransposeOp.
  bool collectFanIn(Operation *op, SetVector<Operation *> &collected);
  bool convertDependentOps(SetVector<Operation *> &dependentOps,
                           DenseMap<Value, Value> &valuesMap,
                           IRRewriter &rewriter,
                           ArrayRef<int32_t> hoistedPerms);

  // Checks if the two permutations, when applied consecutively, result
  // in the identity.
  bool areInvolutionTransposes(ArrayRef<int32_t> perms1,
                               ArrayRef<int32_t> perms2);

  // This is meant to apply to operations with the TosaElementwiseOperator
  // trait.
  std::optional<Value>
  buildMappedToValue(Operation *op, const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // This updates valuesMap when we encounter another TransposeOp as a
  // dependency of the hoisted one.
  // %0 = tosa.transpose %arg0 <- applies to this
  // %1 = tosa.transpose %0    <- when tracking back from this
  std::optional<Value>
  buildMappedToValue(TransposeOp transposeOp,
                     const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // Checks if a ReshapeOp can have the hoisted TransposeOp folded into it.
  // If so, it creates a new ReshapeOp with that fold.
  std::optional<Value>
  buildMappedToValue(ReshapeOp reshapeOp,
                     const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // We may have something like:
  // %0 = tosa.const
  // %1 = tosa.transpose
  // %2 = tosa.add %0, %1
  // %3 = tosa.transpose %2
  // that --tosa-layerwise-const-fold wouldn't handle. This use shows up
  // in MobilenetV3.
  std::optional<Value>
  buildMappedToValue(ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // Checks which TransposeOps we should replace, turning their converted
  // chains of ops (through which they were propagated) "live" and the old
  // code "dead." Attempts to avoid replacements that would leave the old
  // code "live," which would result in duplication.
  std::set<TransposeOp> getGoodReplacements(
      ArrayRef<int32_t> perms,
      std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
          &transposeInfo);

  // Helper function for dependenciesAreValid.
  bool userNotContainedInValidTransposeDependencies(
      Operation *user, std::set<TransposeOp> &validTransposes,
      std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
          &transposeInfo);

  // Helper function for getGoodReplacements to check if some TransposeOp's
  // dependencies are OK.
  bool dependenciesAreValid(
      ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
      std::set<TransposeOp> &validTransposes,
      std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
          &transposeInfo);

  // Applies perms to the DenseElementsAttr.
  // If it returns std::nullopt, it also triggers pass failure, since verifier
  // guarantees from TOSA are not in place (and otherwise, if used elsewhere,
  // it should fail).
  // This is a basic API and may benefit from refactoring into the core MLIR
  // APIs.
  std::optional<DenseElementsAttr>
  transposeDenseAttribute(DenseElementsAttr input, ArrayRef<int32_t> perms);
};

std::optional<DenseElementsAttr>
TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input,
                                              ArrayRef<int32_t> perms) {
  RankedTensorType oldType = llvm::cast<RankedTensorType>(input.getType());
  ArrayRef<int64_t> oldShape = oldType.getShape();
  int64_t rank = oldType.getRank();

  // Asserted by the TransposeOp verifier and TOSA disallowing tensors with a
  // dimension of 0. If not in place, something is very wrong.
  if (rank <= 0 || oldType.getNumElements() <= 0) {
    signalPassFailure();
    return std::nullopt;
  }

  auto newShape = applyTOSAPermutation(oldShape, perms);
  RankedTensorType newType =
      RankedTensorType::get(newShape, oldType.getElementType());

  if (input.isSplat()) {
    return input.reshape(newType);
  }

  auto rawData = input.getRawData();
  if (!rawData.data()) {
    return std::nullopt;
  }

  // The algorithm is approximately as follows:
  // 1. Determine the strides of both input and output tensors in row-major
  //    order.
  // 2. Iterate through the output tensor linearly.
  // 3. For each output position, decompose the linear index into
  //    multi-dimensional coordinates using output strides.
  // 4. Use the permutation to map output coordinates to input coordinates and
  //    calculate the source linear index.

  // Example: perms [2, 0, 1]; input 2x3x4; output 4x2x3
  // for output linear index 11: decompose to output[1][1][2]
  // using output strides [6, 3, 1]. Map to input coordinates using
  // perms: dim 0 -> 2, dim 1 -> 0, dim 2 -> 1, giving source position
  // calculated as 1*inputStrides[2] + 1*inputStrides[0] + 2*inputStrides[1]
  // = 1*1 + 1*12 + 2*4 = 21

  size_t elementSize = oldType.getElementTypeBitWidth() / 8;
  int64_t numElements = oldType.getNumElements();

  SmallVector<char> outputBuffer(numElements * elementSize);
  const char *inputPtr = rawData.data();
  char *outputPtr = outputBuffer.data();

  auto calculateStrides = [](ArrayRef<int64_t> shape) -> SmallVector<int64_t> {
    int64_t rank = shape.size();
    SmallVector<int64_t> strides(rank);
    strides[rank - 1] = 1;
    for (int64_t i = rank - 2; i >= 0; --i) {
      strides[i] = strides[i + 1] * shape[i + 1];
    }
    return strides;
  };

  // Calculate strides for both input and output tensors.
  SmallVector<int64_t> inputStrides = calculateStrides(oldShape);
  SmallVector<int64_t> outputStrides = calculateStrides(newShape);

  auto mapCoordinates = [&](int64_t destLinearIndex) -> int64_t {
    int64_t tempDestIndex = destLinearIndex;
    int64_t sourceLinearIndex = 0;

    // Decompose the linear destination index into multi-dimensional
    // coordinates by dividing by output strides.
    // Simultaneously map these coordinates through the permutation
    // to calculate the corresponding source linear index.
    for (auto j : llvm::seq<int64_t>(rank)) {
      int64_t destCoord = tempDestIndex / outputStrides[j];
      tempDestIndex %= outputStrides[j];
      sourceLinearIndex += destCoord * inputStrides[perms[j]];
    }

    return sourceLinearIndex;
  };

  for (auto destLinearIndex : llvm::seq<int64_t>(numElements)) {
    int64_t sourceLinearIndex = mapCoordinates(destLinearIndex);

    // Copy the element from source to destination using type-agnostic byte
    // copying.
    std::memcpy(outputPtr + destLinearIndex * elementSize,
                inputPtr + sourceLinearIndex * elementSize, elementSize);
  }

  return DenseElementsAttr::getFromRawBuffer(newType, outputBuffer);
}

// If the function returns true, the SetVector should only contain ConstOp,
// ReshapeOp, and TransposeOp as the sources of the data dependencies, with
// TosaElementwiseOperator ops after that.
bool TosaReduceTransposes::collectFanIn(Operation *op,
                                        SetVector<Operation *> &collected) {
  // Can occur if defined through the parameter to a func.func.
  if (!op)
    return false;

  if (!llvm::isa_and_present<tosa::TosaDialect>(op->getDialect()))
    return false;

  // Prevent extra work if already seen.
  if (collected.contains(op))
    return true;

  // Throw it out so we don't have to deal with it later.
  if (op->getNumResults() != 1 ||
      !llvm::isa<RankedTensorType>(op->getResult(0).getType()))
    return false;

  // We don't wish to traverse up a ReshapeOp, since generally we can't
  // propagate a TransposeOp through it. TransposeOp, ReshapeOp, ConstOp
  // will have no in-edges in the data dependency graph we construct for
  // the downstream TransposeOp.
  if (!llvm::isa<tosa::TransposeOp>(op) && !llvm::isa<tosa::ReshapeOp>(op) &&
      !llvm::isa<tosa::ConstOp>(op)) {

    if (!llvm::isa<tosa::MulOp>(op) &&
        !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
      return false;

    for (Value operand : op->getOperands()) {
      // If this is a problem in future, think about alternatives to recursion.
      if (llvm::isa<tosa::MulOp>(op) && operand == op->getOperand(2)) {
        // Do not recurse into MulOp's shift operand.
        continue;
      }
      if (!collectFanIn(operand.getDefiningOp(), collected))
        return false;
    }
  }

  // Insert in topological order.
  collected.insert(op);

  return true;
}

// Assuming that, due to the verification of TransposeOp, perms arrays are
// permutations of 0 .. perms.size() - 1.
bool TosaReduceTransposes::areInvolutionTransposes(ArrayRef<int32_t> perms1,
                                                   ArrayRef<int32_t> perms2) {
  if (perms1.size() != perms2.size())
    return false;
  int32_t n = perms1.size();
  for (int32_t i = 0; i < n; i++)
    if (perms2[perms1[i]] != i)
      return false;
  return true;
}

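// For instance (illustrative values, not from the original file):
// perms1 = [1, 2, 0] and perms2 = [2, 0, 1] are inverse permutations, since
// perms2[perms1[i]] == i for every i, so the pair composes to the identity
// (which is exactly what the function above checks).
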
// Primary overload for those with the TosaElementwiseOperator trait.
// The other ones handle the case of the operations that occur at the
// roots of the data dependency graph (ConstOp, ReshapeOp, TransposeOp).
std::optional<Value> TosaReduceTransposes::buildMappedToValue(
    Operation *op, const DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
  if (op->getNumResults() != 1 ||
      (!llvm::isa<tosa::MulOp>(op) &&
       !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>()))
    return std::nullopt;

  auto resultType = llvm::cast<RankedTensorType>(op->getResult(0).getType());
  SmallVector<Value, 3> operands;
  for (Value v : op->getOperands()) {
    if (valuesMap.contains(v)) {
      operands.push_back(valuesMap.at(v));
    } else if (llvm::isa<tosa::MulOp>(op) && v == op->getOperand(2)) {
      // Special case for MulOp's shift operand.
      operands.push_back(v);
    } else {
      return std::nullopt;
    }
  }

  // Conceptually, we propagate the hoisted TransposeOp through
  // these intervening operations. For example,

  // %0 = tosa.clamp %input : (tensor<2x3xi32>) -> tensor<2x3xi32>
  // %1 = tosa.transpose %0 {perms = [1, 0]} : (tensor<2x3xi32>) ->
  //      tensor<3x2xi32>

  // becomes:
  // %0 = tosa.transpose %input {perms = [1, 0]} : (tensor<2x3xi32>) ->
  //      tensor<3x2xi32>
  // %1 = tosa.clamp %0 : (tensor<3x2xi32>) -> tensor<3x2xi32>

  // We construct this new tosa.clamp here, but it doesn't
  // turn "live" until the transpose being hoisted through this chain
  // is replaced with the proper value from the new chain.

  return rewriter
      .create(op->getLoc(), op->getName().getIdentifier(), operands,
              RankedTensorType::get(
                  applyTOSAPermutation(resultType.getShape(), hoistedPerms),
                  resultType.getElementType()),
              op->getAttrs())
      ->getResult(0);
}

std::optional<Value> TosaReduceTransposes::buildMappedToValue(
    TransposeOp transposeOp, const DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
  if (!areInvolutionTransposes(hoistedPerms, transposeOp.getPerms()))
    return std::nullopt;
  return transposeOp.getInput1();
}

std::optional<Value> TosaReduceTransposes::buildMappedToValue(
    ReshapeOp reshapeOp, const DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
  auto reshapeOutput = reshapeOp.getOutput();
  auto reshapeInputType =
      llvm::dyn_cast<RankedTensorType>(reshapeOp.getInput1().getType());
  // Want reshape N -> 1x1x...x1xNx1x...x1x1. Check the type before querying
  // its shape, since the dyn_cast may have failed.
  if (!reshapeInputType || reshapeInputType.getShape().size() != 1)
    return std::nullopt;
  auto reshapeInputShape = reshapeInputType.getShape();
  auto reshapeOutputType =
      llvm::cast<RankedTensorType>(reshapeOutput.getType());

  // Instead of inserting a TransposeOp here, we check if we can fold it into
  // the ReshapeOp. There are more complex cases where this is possible, and
  // this check can be extended (an illustrative sketch follows this
  // function).

  // Checking if the reshape is N -> 1x1x...x1xNx1x...x1x1.
  auto shape = reshapeOutputType.getShape();
  size_t ones = llvm::count(shape, 1);
  // Handles both the N == 1 and N != 1 cases.
  if (ones != shape.size() - 1 &&
      !(ones == shape.size() && reshapeInputShape[0] == 1))
    return std::nullopt;

  // Do not insert a TransposeOp; instead, we fold the reshape and its
  // attribute.
  llvm::SmallVector<int64_t> newShape;
  if (!tosa::getConstShapeValues(reshapeOp.getShape().getDefiningOp(),
                                 newShape)) {
    // This means the shape is not constant.
    return std::nullopt;
  }
  ImplicitLocOpBuilder builder(reshapeOp.getLoc(), rewriter);
  auto foldedReshape = ReshapeOp::create(
      rewriter, reshapeOp.getLoc(),
      RankedTensorType::get(applyTOSAPermutation(shape, hoistedPerms),
                            reshapeOutputType.getElementType()),
      reshapeOp.getInput1(),
      getTosaConstShape(builder, applyTOSAPermutation(llvm::ArrayRef(newShape),
                                                      hoistedPerms)));
  return foldedReshape->getResult(0);
}

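// Illustrative sketch of the reshape fold above (values invented for this
// example): with hoistedPerms = [2, 0, 1], a reshape
// tensor<5xf32> -> tensor<1x5x1xf32> folds to
// tensor<5xf32> -> tensor<1x1x5xf32>, which matches the result of
// transposing the original reshape's output by [2, 0, 1]. Since only one
// dimension is non-unit, no data movement is needed.
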
std::optional<Value> TosaReduceTransposes::buildMappedToValue(
    ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
  auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(constOp.getValues());
  if (!denseAttr)
    return std::nullopt;
  auto maybeNewDenseAttr = transposeDenseAttribute(denseAttr, hoistedPerms);
  if (!maybeNewDenseAttr.has_value())
    return std::nullopt;
  auto newDenseAttr = maybeNewDenseAttr.value();
  auto newConstOp = ConstOp::create(rewriter, constOp.getLoc(),
                                    newDenseAttr.getType(), newDenseAttr);
  return newConstOp->getResult(0);
}

bool TosaReduceTransposes::convertDependentOps(
    SetVector<Operation *> &dependentOps, DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {

  for (Operation *op : dependentOps) {
    if (!op || op->getNumResults() != 1)
      return false;

    Value priorValue = op->getResult(0);

    // It's possible that on a prior TransposeOp we had the same dependency
    // and already resolved it.
    if (valuesMap.contains(priorValue))
      continue;

    // Keep converted ops close to the original.
    rewriter.setInsertionPointAfter(op);

    std::optional<Value> maybeValue =
        llvm::TypeSwitch<Operation *, std::optional<Value>>(op)
            .Case<TransposeOp, ReshapeOp, ConstOp>([&](auto transposeOp) {
              return buildMappedToValue(transposeOp, valuesMap, rewriter,
                                        hoistedPerms);
            })
            .Default([&](Operation *op) {
              return buildMappedToValue(op, valuesMap, rewriter, hoistedPerms);
            });

    if (!maybeValue.has_value())
      return false;

    valuesMap[priorValue] = maybeValue.value();
  }

  return true;
}

bool TosaReduceTransposes::userNotContainedInValidTransposeDependencies(
    Operation *user, std::set<TransposeOp> &validTransposes,
    std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
        &transposeInfo) {
  return llvm::none_of(
      transposeInfo,
      [&validTransposes,
       user](const std::pair<TransposeOp, SetVector<Operation *>> &info) {
        const auto &[transposeOp, dependentOps] = info;
        return validTransposes.count(transposeOp) &&
               dependentOps.contains(user);
      });
}

// Dependencies are valid for an operation if none of them occur outside
// of the proper fan-in cones of the hoisted TransposeOp with the same perms
// that we can replace. Described in more detail within.
bool TosaReduceTransposes::dependenciesAreValid(
    ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
    std::set<TransposeOp> &validTransposes,
    std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
        &transposeInfo) {
  for (Operation *op : dependentOps) {

    // It's OK wherever ConstOp has uses -- in the worst case, we duplicate.
    // This can be changed later if we find the memory impact is too high.
    if (llvm::isa<ConstOp>(op))
      continue;

    for (OpOperand &use : op->getUses()) {
      // Want the uses to be (1) contained in the dependentOps of other
      // validTransposes, or (2) directly used in a TransposeOp with the
      // same perms. For (2), it means the fan-in is a subset of our
      // dependentOps, so it is also a validTranspose that will eventually be
      // replaced.
      Operation *user = use.getOwner();
      if (auto otherTranspose = llvm::dyn_cast<TransposeOp>(user)) {
        // Can later think about cases where transpose -> transpose
        // or reshape -> transpose, where the transposes are not necessarily
        // the same perms as the hoisted one, if implementing a more general
        // transform. These could be permitted.
        if (!llvm::equal(perms, otherTranspose.getPerms()))
          return false;
      } else if (userNotContainedInValidTransposeDependencies(
                     user, validTransposes, transposeInfo)) {
        return false;
      }
    }
  }

  return true;
}

// Getting the set of TransposeOps that we can replace without causing
// the old fan-in cones of any TransposeOp to remain "live", i.e., not be
// dead code. This is done by iterating the set until convergence, since
// if an op is used outside its own fan-in cone, it's possible for it to be
// used in another fan-in cone of a TransposeOp that is being replaced --
// unless we find that that one has a usage outside of it too.
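// For example (hypothetical scenario): if transposes %t1 and %t2 share the
// same perms, and an op in %t1's fan-in cone is also used inside %t2's cone,
// then %t1 is only replaceable while %t2 is. If %t2 is later invalidated by
// a use outside every replaceable cone, %t1 must be re-examined, hence the
// iteration to convergence below.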
std::set<TransposeOp> TosaReduceTransposes::getGoodReplacements(
    ArrayRef<int32_t> perms,
    std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
        &transposeInfo) {
  // Initially, we assume they are all good to replace,
  // and we whittle them down based on our criteria.
  std::set<TransposeOp> ableToReplace;
  for (const auto &[transposeOp, _] : transposeInfo)
    ableToReplace.insert(transposeOp);

  bool gotRid;
  do {
    gotRid = false;
    for (const auto &[transposeOp, dependentOps] : transposeInfo) {
      // We don't care about it. Already invalidated.
      if (!ableToReplace.count(transposeOp))
        continue;

      // Check for validity.
      if (!dependenciesAreValid(perms, dependentOps, ableToReplace,
                                transposeInfo)) {
        ableToReplace.erase(transposeOp);
        gotRid = true;
        break;
      }
    }

  } while (gotRid);

  return ableToReplace;
}

void TosaReduceTransposes::runOnOperation() {
  // We want to operate only within a single block.
  if (!getOperation().getRegion().hasOneBlock())
    return;

  IRRewriter rewriter(&getContext());
  // For each perms, maintain a mapping for converted ops, avoiding
  // duplication.
  DenseMap<ArrayRef<int32_t>, DenseMap<Value, Value>> permsToValues;
  // For each perms, we keep track of which TransposeOps are eligible
  // for replacement alongside their dependentOps.
  DenseMap<ArrayRef<int32_t>,
           std::vector<std::pair<TransposeOp, SetVector<Operation *>>>>
      permsToTransposeInfo;

  // Necessary for lifetime, since DenseMap keeps a copy of the ArrayRef.
  // Use SmallVector for the perms themselves (the common case is <= 4) but
  // std::vector for the outer container, since there is no guarantee of
  // smallness there.
  std::vector<SmallVector<int32_t>> collectedPerms;

  // This keeps track of the order across all eligible-for-replacement
  // TransposeOps and their perms, a necessity for the final replacements.
  std::stack<std::pair<TransposeOp, ArrayRef<int32_t>>> totalTransposeOrder;

  // We want to reserve the space up front, since SmallVector stores some data
  // internally and the ArrayRef can reference that, which we don't want to
  // get invalidated.
  size_t expectedMaxPerms = 0;
  getOperation().walk([&](TransposeOp) { expectedMaxPerms += 1; });
  collectedPerms.reserve(expectedMaxPerms);

  getOperation().walk([&](TransposeOp transposeOp) {
    SetVector<Operation *> dependentOps;
    collectedPerms.emplace_back();
    SmallVector<int32_t> &perms = collectedPerms.back();

    // Dynamic shapes are OK, but the incompatible ones will be rejected
    // later.
    auto input = transposeOp.getInput1();
    auto output = transposeOp.getOutput();

    // However, we don't support unranked tensors.
    if (!llvm::isa<RankedTensorType>(input.getType()) ||
        !llvm::isa<RankedTensorType>(output.getType()))
      return;

    llvm::append_range(perms, transposeOp.getPerms());

    // We let --canonicalize deal with identity transposes.
    if (llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
      return;

    // Can fail if some basic invariants we require to perform our
    // conversions are not met.
    if (!collectFanIn(input.getDefiningOp(), dependentOps))
      return;

    // Want to reuse the valuesMap of ops already converted for the same
    // perms, since it's possible multiple hoisted transposes with different
    // perms converge on an op, which would result in different
    // transformations.
    DenseMap<Value, Value> &valuesMap = permsToValues[perms];

    // Attempt to perform the conversions and placements into IR
    // without turning inserted code "live". Also fills out valuesMap.
    // Fails if there is an intermediary we do not support.
    if (!convertDependentOps(dependentOps, valuesMap, rewriter, perms))
      // Some additional operations may have been inserted, but will be
      // removed by dead code elimination.
      return;

    // This should not happen. If it does -- it's unexpected,
    // so we fail the pass.
    if (!valuesMap.contains(input))
      return signalPassFailure();

    // It's possible the types are not compatible (because of dynamic
    // shapes), and in these cases, we want to resolve dynamic shapes before
    // running the pass.
    if (output.getType() != valuesMap.at(input).getType())
      return;

    auto &transposeInfo = permsToTransposeInfo[perms];

    // In general, we might also want to introduce "newDependentOps"
    // if there are new usages that don't fall inside the original fan-ins
    // (like the TransposeOp we insert for ReshapeOp),
    // but in this case, that is specialized enough and overlaps
    // with another direct-use TransposeOp case we need to cover anyway.
    transposeInfo.emplace_back(transposeOp, dependentOps);

    // This is for the final replacement across all transposes.
    totalTransposeOrder.emplace(transposeOp, perms);
  });

  // We want to do a full fan-in analysis per perms, since if we did it
  // across multiple perms and they share dependencies (e.g., a shared
  // dependency on a Reshape), we would also get duplicate ops.
  // Const is special-cased.
  std::set<TransposeOp> ableToReplace;
  for (auto &[perms, transposeInfo] : permsToTransposeInfo) {
    // Gives us back replacements that would never result in any duplicate
    // operations being inserted by us in the IR (i.e., our goal is only to
    // remove transposes, and not to create a "new chain" to do so, but to
    // replace the existing chains).
    // Ideally, --canonicalize is run before this pass, since it helps this
    // analysis by removing dead code to allow more potentially acceptable
    // transformations.
    auto goodReplacementsForPerms = getGoodReplacements(perms, transposeInfo);
    ableToReplace.insert(goodReplacementsForPerms.begin(),
                         goodReplacementsForPerms.end());
  }

  // We want to do replacement across all transposes
  // in reverse order, due to invalidation of valuesMap mappings
  // if we did it otherwise.
  while (!totalTransposeOrder.empty()) {
    auto [transposeOp, perms] = totalTransposeOrder.top();
    totalTransposeOrder.pop();

    if (ableToReplace.count(transposeOp) == 0)
      continue;

    auto &valuesMap = permsToValues[perms];
    auto input = transposeOp.getInput1();

    // The purpose of this reverse iteration
    // is to avoid valuesMap invalidation. If it happens,
    // something is wrong.
    if (!valuesMap.contains(input))
      return signalPassFailure();

    rewriter.replaceOp(transposeOp, valuesMap.at(input));
  }

  // We can remove all dead code by going in reverse.
  // This is because we would remove usages before we
  // see the users.
  getOperation().walk<WalkOrder::PostOrder, ReverseIterator>(
      [&](Operation *op) {
        if (isOpTriviallyDead(op))
          rewriter.eraseOp(op);
      });
}

} // namespace