76 #include "llvm/ADT/TypeSwitch.h"
83 #define GEN_PASS_DEF_TOSAREDUCETRANSPOSES
84 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
97 struct TosaReduceTransposes final
98 :
public tosa::impl::TosaReduceTransposesBase<TosaReduceTransposes> {
99 void runOnOperation()
override;
125 buildMappedToValue(TransposeOp transposeOp,
132 buildMappedToValue(ReshapeOp reshapeOp,
151 std::set<TransposeOp> getGoodReplacements(
157 bool userNotContainedInValidTransposeDependencies(
158 Operation *user, std::set<TransposeOp> &validTransposes,
164 bool dependenciesAreValid(
166 std::set<TransposeOp> &validTransposes,
175 std::optional<DenseElementsAttr>
179 std::optional<DenseElementsAttr>
182 RankedTensorType oldType = llvm::cast<RankedTensorType>(input.
getType());
183 RankedTensorType newType =
185 oldType.getElementType());
186 size_t rank = oldType.getRank();
190 if (rank <= 0 || oldType.getNumElements() <= 0 || perms.size() != rank) {
213 originalInputStrides[rank - 1] = 1;
215 for (int64_t i = rank - 2; i >= 0; i--)
216 originalInputStrides[i] =
217 originalInputStrides[i + 1] * oldType.getDimSize(i + 1);
223 newInputStrides.reserve(rank);
224 for (int32_t v : perms)
225 newInputStrides.push_back(originalInputStrides[v]);
231 for (
size_t i = 0; i < rank; i++)
232 boundsAndStrides.push_back({newType.getDimSize(i), newInputStrides[i]});
235 resultArray.reserve(inputArray.size());
237 std::function<void(int64_t,
238 SmallVector<std::pair<int64_t, int64_t>>::const_iterator)>
239 processTransposeDim = [&](
auto accumulatedIndex,
auto it) {
240 if (it == boundsAndStrides.end()) {
241 resultArray.push_back(inputArray[accumulatedIndex]);
245 for (int64_t i = 0; i < it->first; i++) {
246 int64_t
j = accumulatedIndex + i * it->second;
247 processTransposeDim(
j, it + 1);
251 processTransposeDim(0, boundsAndStrides.begin());
259 bool TosaReduceTransposes::collectFanIn(
Operation *op,
265 if (!llvm::isa_and_present<tosa::TosaDialect>(op->
getDialect()))
269 if (collected.contains(op))
281 if (!llvm::isa<tosa::TransposeOp>(op) && !llvm::isa<tosa::ReshapeOp>(op) &&
282 !llvm::isa<tosa::ConstOp>(op)) {
289 if (!collectFanIn(operand.getDefiningOp(), collected))
294 collected.insert(op);
303 if (perms1.size() != perms2.size())
305 int32_t n = perms1.size();
306 for (int32_t i = 0; i < n; i++)
307 if (perms2[perms1[i]] != i)
315 std::optional<Value> TosaReduceTransposes::buildMappedToValue(
325 if (valuesMap.contains(v)) {
326 operands.push_back(valuesMap.at(v));
352 resultType.getElementType()),
357 std::optional<Value> TosaReduceTransposes::buildMappedToValue(
361 if (failed(transposeOp.getConstantPerms(perms)) ||
362 !areInvolutionTransposes(hoistedPerms, perms))
364 return transposeOp.getInput1();
367 std::optional<Value> TosaReduceTransposes::buildMappedToValue(
370 auto reshapeOutput = reshapeOp.getOutput();
371 auto reshapeInputType =
372 llvm::dyn_cast<RankedTensorType>(reshapeOp.getInput1().getType());
373 auto reshapeInputShape = reshapeInputType.getShape();
375 if (!reshapeInputType || reshapeInputShape.size() != 1)
377 auto reshapeOutputType =
378 llvm::cast<RankedTensorType>(reshapeOutput.getType());
385 auto shape = reshapeOutputType.getShape();
386 size_t ones = llvm::count(shape, 1);
388 if (ones != shape.size() - 1 &&
389 !(ones == shape.size() && reshapeInputShape[0] == 1))
393 auto foldedReshape = rewriter.
create<ReshapeOp>(
396 reshapeOutputType.getElementType()),
397 reshapeOp.getInput1(),
403 std::optional<Value> TosaReduceTransposes::buildMappedToValue(
406 auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(constOp.getValue());
409 auto maybeNewDenseAttr = transposeDenseAttribute(denseAttr, hoistedPerms);
410 if (!maybeNewDenseAttr.has_value())
412 auto newDenseAttr = maybeNewDenseAttr.value();
413 auto newConstOp = rewriter.
create<ConstOp>(
414 constOp.getLoc(), newDenseAttr.getType(), newDenseAttr);
418 bool TosaReduceTransposes::convertDependentOps(
430 if (valuesMap.contains(priorValue))
436 std::optional<Value> maybeValue =
438 .Case<TransposeOp, ReshapeOp, ConstOp>([&](
auto transposeOp) {
439 return buildMappedToValue(transposeOp, valuesMap, rewriter,
443 return buildMappedToValue(op, valuesMap, rewriter, hoistedPerms);
446 if (!maybeValue.has_value())
449 valuesMap[priorValue] = maybeValue.value();
455 bool TosaReduceTransposes::userNotContainedInValidTransposeDependencies(
456 Operation *user, std::set<TransposeOp> &validTransposes,
459 return llvm::none_of(
463 const auto &[transposeOp, dependentOps] = info;
464 return validTransposes.count(transposeOp) &&
465 dependentOps.contains(user);
472 bool TosaReduceTransposes::dependenciesAreValid(
474 std::set<TransposeOp> &validTransposes,
481 if (llvm::isa<ConstOp>(op))
491 if (
auto otherTranspose = llvm::dyn_cast<TransposeOp>(user)) {
498 if (failed(otherTranspose.getConstantPerms(otherPerms)) ||
499 !llvm::equal(perms, otherPerms))
501 }
else if (userNotContainedInValidTransposeDependencies(
502 user, validTransposes, transposeInfo)) {
517 std::set<TransposeOp> TosaReduceTransposes::getGoodReplacements(
523 std::set<TransposeOp> ableToReplace;
524 for (
const auto &[transposeOp, _] : transposeInfo)
525 ableToReplace.insert(transposeOp);
530 for (
const auto &[transposeOp, dependentOps] : transposeInfo) {
532 if (!ableToReplace.count(transposeOp))
536 if (!dependenciesAreValid(perms, dependentOps, ableToReplace,
538 ableToReplace.erase(transposeOp);
546 return ableToReplace;
549 void TosaReduceTransposes::runOnOperation() {
551 if (!getOperation().getRegion().hasOneBlock())
560 std::vector<std::pair<TransposeOp, SetVector<Operation *>>>>
561 permsToTransposeInfo;
566 std::vector<SmallVector<int32_t>> collectedPerms;
570 std::stack<std::pair<TransposeOp, ArrayRef<int32_t>>> totalTransposeOrder;
575 size_t expectedMaxPerms = 0;
576 getOperation().walk([&](TransposeOp) { expectedMaxPerms += 1; });
577 collectedPerms.reserve(expectedMaxPerms);
579 getOperation().walk([&](TransposeOp transposeOp) {
581 collectedPerms.emplace_back();
585 auto input = transposeOp.getInput1();
586 auto output = transposeOp.getOutput();
589 if (!llvm::isa<RankedTensorType>(input.
getType()) ||
590 !llvm::isa<RankedTensorType>(output.getType()))
594 if (failed(transposeOp.getConstantPerms(perms)))
598 if (llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
603 if (!collectFanIn(input.getDefiningOp(), dependentOps))
614 if (!convertDependentOps(dependentOps, valuesMap, rewriter, perms))
621 if (!valuesMap.contains(input))
622 return signalPassFailure();
627 if (output.getType() != valuesMap.at(input).getType())
630 auto &transposeInfo = permsToTransposeInfo[perms];
637 transposeInfo.push_back({transposeOp, dependentOps});
640 totalTransposeOrder.push({transposeOp, perms});
647 std::set<TransposeOp> ableToReplace;
648 for (
auto &[perms, transposeInfo] : permsToTransposeInfo) {
656 auto goodReplacementsForPerms = getGoodReplacements(perms, transposeInfo);
657 ableToReplace.insert(goodReplacementsForPerms.begin(),
658 goodReplacementsForPerms.end());
664 while (!totalTransposeOrder.empty()) {
665 auto [transposeOp, perms] = totalTransposeOrder.top();
666 totalTransposeOrder.pop();
668 if (ableToReplace.count(transposeOp) == 0)
671 auto &valuesMap = permsToValues[perms];
672 auto input = transposeOp.getInput1();
677 if (!valuesMap.contains(input))
678 return signalPassFailure();
680 rewriter.
replaceOp(transposeOp, valuesMap.at(input));
static MLIRContext * getContext(OpFoldResult val)
Attributes are known-constant values of operations.
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
This class indicates that an op is tosa-elementwise (permits broadcasting, unlike Elementwise trait).
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
unsigned getNumResults()
Return the number of results held by this operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
SmallVector< T > applyTOSAPermutation(ArrayRef< T > input, ArrayRef< int32_t > perms)
Include the generated interface declarations.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This iterator enumerates elements in "reverse" order.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.