81 LogicalResult matchAndRewrite(GenericOp genericOp,
91 computeTransposeBroadcast(
AffineMap &map) {
99 for (int64_t i = 0; i < minorSize; ++i) {
100 auto expr = cast<AffineDimExpr>(map.
getResults()[i]);
101 minorResult.push_back(expr.getPosition());
106 std::sort(sortedResMap.begin(), sortedResMap.end());
107 bool hasTranspose = !std::equal(minorResult.begin(), minorResult.end(),
108 sortedResMap.begin(), sortedResMap.end());
113 if (
j < minorSize && sortedResMap[
j] == i) {
129 permutation.resize(minorSize);
130 std::map<int64_t, int64_t> minorMap;
131 for (int64_t i = 0; i < minorSize; ++i)
132 minorMap.insert({sortedResMap[i], i});
136 for (int64_t i = 0; i < minorSize; ++i)
137 remappedResult[i] = minorMap[minorResult[i]];
140 for (
unsigned i = 0; i < minorSize; ++i) {
141 permutation[remappedResult[i]] = i;
147 LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
149 if (!op.hasPureTensorSemantics() || op.isSingleInputOutput() ||
150 op.isSingleYieldOp() || !op.isAllParallelLoops())
158 for (
auto &opOperand : op->getOpOperands()) {
159 auto map = op.getMatchingIndexingMap(&opOperand);
168 if (llvm::any_of(op->getOpOperands(), [](
OpOperand &oper) {
169 auto opType = cast<RankedTensorType>(oper.get().getType());
170 return ShapedType::isDynamicShape(opType.getShape());
174 auto outputShape = op.getStaticLoopRanges();
176 auto loc = op.getLoc();
177 bool isChanged =
false;
183 for (int64_t i = 0; i < op.getNumDpsInputs(); ++i) {
184 auto &map = newMap[i];
185 auto inputRTType = cast<RankedTensorType>(newInitValues[i].
getType());
186 auto elType = inputRTType.getElementType();
192 auto [permutation, broadcastedDims] = computeTransposeBroadcast(map);
195 if (!permutation.empty()) {
200 transposedShape[i] = inputRTType.getShape()[permutation[i]];
203 rewriter.
create<tensor::EmptyOp>(loc, transposedShape, elType);
205 auto transposeOp = rewriter.
create<TransposeOp>(loc, newInitValues[i],
206 emptyTensor, permutation);
207 newInitValues[i] = transposeOp->
getResult(0);
212 if (!broadcastedDims.empty()) {
213 assert(broadcastedDims.size() &&
"should have non size broadcast");
214 Value emptyTensor = rewriter.
create<tensor::EmptyOp>(
215 loc, outputShape, inputRTType.getElementType());
217 auto broadcastOp = rewriter.
create<linalg::BroadcastOp>(
218 loc, newInitValues[i], emptyTensor, broadcastedDims);
220 newInitValues[i] = broadcastOp->
getResult(0);
230 auto newOp = rewriter.
create<linalg::GenericOp>(
232 op->getResultTypes(),
234 operandsRef.drop_front(op.getNumDpsInputs()),
236 op.getIteratorTypesArray());
239 rewriter.
replaceOp(op, newOp->getResults());
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to a vector with numElements elements.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
unsigned getNumInputs() const
bool isIdentity() const
Returns true if this affine map is an identity affine map.
AffineMap getMultiDimIdentityMap(unsigned rank)
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void populateDecomposeProjectedPermutationPatterns(RewritePatternSet &patterns)
Add patterns to make explicit broadcasts and transforms in the input operands of a genericOp.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.