75 #include "llvm/ADT/TypeSwitch.h"
81 #define GEN_PASS_DEF_TOSAREDUCETRANSPOSES
82 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
95 struct TosaReduceTransposes final
96 :
public tosa::impl::TosaReduceTransposesBase<TosaReduceTransposes> {
97 void runOnOperation()
override;
123 buildMappedToValue(TransposeOp transposeOp,
130 buildMappedToValue(ReshapeOp reshapeOp,
149 std::set<TransposeOp> getGoodReplacements(
155 bool userNotContainedInValidTransposeDependencies(
156 Operation *user, std::set<TransposeOp> &validTransposes,
162 bool dependenciesAreValid(
164 std::set<TransposeOp> &validTransposes,
173 std::optional<DenseElementsAttr>
177 std::optional<DenseElementsAttr>
180 RankedTensorType oldType = llvm::cast<RankedTensorType>(input.
getType());
182 int64_t rank = oldType.getRank();
186 if (rank <= 0 || oldType.getNumElements() <= 0) {
192 RankedTensorType newType =
200 if (!rawData.data()) {
220 size_t elementSize = oldType.getElementTypeBitWidth() / 8;
221 int64_t numElements = oldType.getNumElements();
224 const char *inputPtr = rawData.data();
225 char *outputPtr = outputBuffer.data();
228 int64_t rank = shape.size();
230 strides[rank - 1] = 1;
231 for (int64_t i = rank - 2; i >= 0; --i) {
232 strides[i] = strides[i + 1] * shape[i + 1];
241 auto mapCoordinates = [&](int64_t destLinearIndex) -> int64_t {
242 int64_t tempDestIndex = destLinearIndex;
243 int64_t sourceLinearIndex = 0;
249 for (
auto j : llvm::seq<int64_t>(rank)) {
250 int64_t destCoord = tempDestIndex / outputStrides[
j];
251 tempDestIndex %= outputStrides[
j];
252 sourceLinearIndex += destCoord * inputStrides[perms[
j]];
255 return sourceLinearIndex;
258 for (
auto destLinearIndex : llvm::seq<int64_t>(numElements)) {
259 int64_t sourceLinearIndex = mapCoordinates(destLinearIndex);
263 std::memcpy(outputPtr + destLinearIndex * elementSize,
264 inputPtr + sourceLinearIndex * elementSize, elementSize);
273 bool TosaReduceTransposes::collectFanIn(
Operation *op,
279 if (!llvm::isa_and_present<tosa::TosaDialect>(op->
getDialect()))
283 if (collected.contains(op))
295 if (!llvm::isa<tosa::TransposeOp>(op) && !llvm::isa<tosa::ReshapeOp>(op) &&
296 !llvm::isa<tosa::ConstOp>(op)) {
298 if (!llvm::isa<tosa::MulOp>(op) &&
304 if (llvm::isa<tosa::MulOp>(op) && operand == op->
getOperand(2)) {
308 if (!collectFanIn(operand.getDefiningOp(), collected))
314 collected.insert(op);
323 if (perms1.size() != perms2.size())
325 int32_t n = perms1.size();
326 for (int32_t i = 0; i < n; i++)
327 if (perms2[perms1[i]] != i)
335 std::optional<Value> TosaReduceTransposes::buildMappedToValue(
339 (!llvm::isa<tosa::MulOp>(op) &&
346 if (valuesMap.contains(v)) {
347 operands.push_back(valuesMap.at(v));
348 }
else if (llvm::isa<tosa::MulOp>(op) && v == op->
getOperand(2)) {
350 operands.push_back(v);
376 resultType.getElementType()),
381 std::optional<Value> TosaReduceTransposes::buildMappedToValue(
384 if (!areInvolutionTransposes(hoistedPerms, transposeOp.getPerms()))
386 return transposeOp.getInput1();
389 std::optional<Value> TosaReduceTransposes::buildMappedToValue(
392 auto reshapeOutput = reshapeOp.getOutput();
393 auto reshapeInputType =
394 llvm::dyn_cast<RankedTensorType>(reshapeOp.getInput1().getType());
395 auto reshapeInputShape = reshapeInputType.getShape();
397 if (!reshapeInputType || reshapeInputShape.size() != 1)
399 auto reshapeOutputType =
400 llvm::cast<RankedTensorType>(reshapeOutput.getType());
407 auto shape = reshapeOutputType.getShape();
408 size_t ones = llvm::count(shape, 1);
410 if (ones != shape.size() - 1 &&
411 !(ones == shape.size() && reshapeInputShape[0] == 1))
422 auto foldedReshape = ReshapeOp::create(
423 rewriter, reshapeOp.getLoc(),
425 reshapeOutputType.getElementType()),
426 reshapeOp.getInput1(),
429 return foldedReshape->getResult(0);
432 std::optional<Value> TosaReduceTransposes::buildMappedToValue(
435 auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(constOp.getValues());
438 auto maybeNewDenseAttr = transposeDenseAttribute(denseAttr, hoistedPerms);
439 if (!maybeNewDenseAttr.has_value())
441 auto newDenseAttr = maybeNewDenseAttr.value();
442 auto newConstOp = ConstOp::create(rewriter, constOp.getLoc(),
443 newDenseAttr.getType(), newDenseAttr);
444 return newConstOp->getResult(0);
447 bool TosaReduceTransposes::convertDependentOps(
459 if (valuesMap.contains(priorValue))
465 std::optional<Value> maybeValue =
467 .Case<TransposeOp, ReshapeOp, ConstOp>([&](
auto transposeOp) {
468 return buildMappedToValue(transposeOp, valuesMap, rewriter,
472 return buildMappedToValue(op, valuesMap, rewriter, hoistedPerms);
475 if (!maybeValue.has_value())
478 valuesMap[priorValue] = maybeValue.value();
484 bool TosaReduceTransposes::userNotContainedInValidTransposeDependencies(
485 Operation *user, std::set<TransposeOp> &validTransposes,
488 return llvm::none_of(
492 const auto &[transposeOp, dependentOps] = info;
493 return validTransposes.count(transposeOp) &&
494 dependentOps.contains(user);
501 bool TosaReduceTransposes::dependenciesAreValid(
503 std::set<TransposeOp> &validTransposes,
510 if (llvm::isa<ConstOp>(op))
520 if (
auto otherTranspose = llvm::dyn_cast<TransposeOp>(user)) {
525 if (!llvm::equal(perms, otherTranspose.getPerms()))
527 }
else if (userNotContainedInValidTransposeDependencies(
528 user, validTransposes, transposeInfo)) {
543 std::set<TransposeOp> TosaReduceTransposes::getGoodReplacements(
549 std::set<TransposeOp> ableToReplace;
550 for (
const auto &[transposeOp, _] : transposeInfo)
551 ableToReplace.insert(transposeOp);
556 for (
const auto &[transposeOp, dependentOps] : transposeInfo) {
558 if (!ableToReplace.count(transposeOp))
562 if (!dependenciesAreValid(perms, dependentOps, ableToReplace,
564 ableToReplace.
erase(transposeOp);
572 return ableToReplace;
575 void TosaReduceTransposes::runOnOperation() {
577 if (!getOperation().getRegion().hasOneBlock())
586 std::vector<std::pair<TransposeOp, SetVector<Operation *>>>>
587 permsToTransposeInfo;
592 std::vector<SmallVector<int32_t>> collectedPerms;
596 std::stack<std::pair<TransposeOp, ArrayRef<int32_t>>> totalTransposeOrder;
601 size_t expectedMaxPerms = 0;
602 getOperation().walk([&](TransposeOp) { expectedMaxPerms += 1; });
603 collectedPerms.reserve(expectedMaxPerms);
605 getOperation().walk([&](TransposeOp transposeOp) {
607 collectedPerms.emplace_back();
611 auto input = transposeOp.getInput1();
612 auto output = transposeOp.getOutput();
615 if (!llvm::isa<RankedTensorType>(input.
getType()) ||
616 !llvm::isa<RankedTensorType>(output.getType()))
619 llvm::append_range(perms, transposeOp.getPerms());
622 if (llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
627 if (!collectFanIn(input.getDefiningOp(), dependentOps))
638 if (!convertDependentOps(dependentOps, valuesMap, rewriter, perms))
645 if (!valuesMap.contains(input))
646 return signalPassFailure();
651 if (output.getType() != valuesMap.at(input).getType())
654 auto &transposeInfo = permsToTransposeInfo[perms];
661 transposeInfo.emplace_back(transposeOp, dependentOps);
664 totalTransposeOrder.emplace(transposeOp, perms);
671 std::set<TransposeOp> ableToReplace;
672 for (
auto &[perms, transposeInfo] : permsToTransposeInfo) {
680 auto goodReplacementsForPerms = getGoodReplacements(perms, transposeInfo);
681 ableToReplace.insert(goodReplacementsForPerms.begin(),
682 goodReplacementsForPerms.end());
688 while (!totalTransposeOrder.empty()) {
689 auto [transposeOp, perms] = totalTransposeOrder.top();
690 totalTransposeOrder.pop();
692 if (ableToReplace.count(transposeOp) == 0)
695 auto &valuesMap = permsToValues[perms];
696 auto input = transposeOp.getInput1();
701 if (!valuesMap.contains(input))
702 return signalPassFailure();
704 rewriter.
replaceOp(transposeOp, valuesMap.at(input));
static MLIRContext * getContext(OpFoldResult val)
An attribute that represents a reference to a dense vector or tensor object.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
This class indicates that an op is tosa-elementwise (permits broadcasting, unlike Elementwise trait).
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
SmallVector< T > applyTOSAPermutation(ArrayRef< T > input, ArrayRef< int32_t > perms)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This iterator enumerates elements in "reverse" order.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.