79 LogicalResult matchAndRewrite(GenericOp genericOp,
89 computeTransposeBroadcast(
AffineMap &map) {
97 for (int64_t i = 0; i < minorSize; ++i) {
98 auto expr = cast<AffineDimExpr>(map.
getResults()[i]);
99 minorResult.push_back(expr.getPosition());
104 llvm::sort(sortedResMap);
105 bool hasTranspose = !std::equal(minorResult.begin(), minorResult.end(),
106 sortedResMap.begin(), sortedResMap.end());
111 if (
j < minorSize && sortedResMap[
j] == i) {
127 permutation.resize(minorSize);
128 std::map<int64_t, int64_t> minorMap;
129 for (int64_t i = 0; i < minorSize; ++i)
130 minorMap.insert({sortedResMap[i], i});
134 for (int64_t i = 0; i < minorSize; ++i)
135 remappedResult[i] = minorMap[minorResult[i]];
138 for (
unsigned i = 0; i < minorSize; ++i) {
139 permutation[remappedResult[i]] = i;
145 LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
147 if (!op.hasPureTensorSemantics() || op.isSingleInputOutput() ||
148 op.isSingleYieldOp() || !op.isAllParallelLoops())
156 for (
auto &opOperand : op->getOpOperands()) {
157 auto map = op.getMatchingIndexingMap(&opOperand);
166 if (llvm::any_of(op->getOpOperands(), [](
OpOperand &oper) {
167 auto opType = cast<RankedTensorType>(oper.get().getType());
168 return ShapedType::isDynamicShape(opType.getShape());
172 auto outputShape = op.getStaticLoopRanges();
174 auto loc = op.getLoc();
175 bool isChanged =
false;
181 for (int64_t i = 0; i < op.getNumDpsInputs(); ++i) {
182 auto &map = newMap[i];
183 auto inputRTType = cast<RankedTensorType>(newInitValues[i].
getType());
184 auto elType = inputRTType.getElementType();
190 auto [permutation, broadcastedDims] = computeTransposeBroadcast(map);
193 if (!permutation.empty()) {
198 transposedShape[i] = inputRTType.getShape()[permutation[i]];
201 tensor::EmptyOp::create(rewriter, loc, transposedShape, elType);
203 auto transposeOp = TransposeOp::create(rewriter, loc, newInitValues[i],
204 emptyTensor, permutation);
205 newInitValues[i] = transposeOp->getResult(0);
210 if (!broadcastedDims.empty()) {
211 assert(broadcastedDims.size() &&
"should have non size broadcast");
212 Value emptyTensor = tensor::EmptyOp::create(rewriter, loc, outputShape,
213 inputRTType.getElementType());
215 auto broadcastOp = linalg::BroadcastOp::create(
216 rewriter, loc, newInitValues[i], emptyTensor, broadcastedDims);
218 newInitValues[i] = broadcastOp->getResult(0);
230 auto newOp = linalg::GenericOp::create(
233 op->getResultTypes(),
235 operandsRef.drop_front(op.getNumDpsInputs()),
237 op.getIteratorTypesArray());
238 newOp.getRegion().takeBody(op->getRegion(0));
239 rewriter.
replaceOp(op, newOp->getResults());
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to a vector with numElements elements.
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
unsigned getNumInputs() const
bool isIdentity() const
Returns true if this affine map is an identity affine map.
AffineMap getMultiDimIdentityMap(unsigned rank)
This class represents an operand of an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void populateDecomposeProjectedPermutationPatterns(RewritePatternSet &patterns)
Add patterns to make explicit broadcasts and transforms in the input operands of a genericOp.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.