60  assert(invProducerResultIndexMap &&
 
   61         "expected producer result indexing map to be invertible");
 
   63  LinalgOp producer = cast<LinalgOp>(producerOpOperand->
getOwner());
 
   65  AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);
 
   73  return t1.
compose(fusedConsumerArgIndexMap);
 
 
   80    GenericOp producer, GenericOp consumer,
 
   85  for (
auto &op : ops) {
 
   86    for (
auto &opOperand : op->getOpOperands()) {
 
   87      if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
 
   90      indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
 
   93  if (indexingMaps.empty()) {
 
   96    return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
 
  104             indexingMaps, producer.getContext())) != 
AffineMap();
 
 
  113    GenericOp producer, GenericOp consumer, 
OpOperand *fusedOperand) {
 
  114  llvm::SmallDenseSet<int> preservedProducerResults;
 
  118  opOperandsToIgnore.emplace_back(fusedOperand);
 
  120  for (
const auto &producerResult : llvm::enumerate(producer->getResults())) {
 
  121    auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
 
  122    opOperandsToIgnore.emplace_back(outputOperand);
 
  123    if (producer.payloadUsesValueFromOperand(outputOperand) ||
 
  125                                                  opOperandsToIgnore) ||
 
  126        llvm::any_of(producerResult.value().getUsers(), [&](
Operation *user) {
 
  127          return user != consumer.getOperation();
 
  129      preservedProducerResults.insert(producerResult.index());
 
  132      (
void)opOperandsToIgnore.pop_back_val();
 
  135  return preservedProducerResults;
 
 
  144  auto consumer = dyn_cast<GenericOp>(fusedOperand->
getOwner());
 
  147  if (!producer || !consumer)
 
  153  if (!producer.hasPureTensorSemantics() ||
 
  154      !isa<RankedTensorType>(fusedOperand->
get().
getType()))
 
  159  if (producer.getNumParallelLoops() != producer.getNumLoops())
 
  164  if (!consumer.isDpsInput(fusedOperand))
 
  169  AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
 
  170  if (consumerIndexMap.
getNumResults() != producer.getNumLoops())
 
  175  auto producerResult = cast<OpResult>(fusedOperand->
get());
 
  177      producer.getIndexingMapMatchingResult(producerResult);
 
  185  if ((consumer.getNumReductionLoops())) {
 
  186    BitVector coveredDims(consumer.getNumLoops(), 
false);
 
  188    auto addToCoveredDims = [&](
AffineMap map) {
 
  189      for (
auto result : map.getResults())
 
  190        if (
auto dimExpr = dyn_cast<AffineDimExpr>(
result))
 
  191          coveredDims[dimExpr.getPosition()] = 
true;
 
  195         llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
 
  196      Value operand = std::get<0>(pair);
 
  197      if (operand == fusedOperand->
get())
 
  199      AffineMap operandMap = std::get<1>(pair);
 
  200      addToCoveredDims(operandMap);
 
  203    for (
OpOperand *operand : producer.getDpsInputOperands()) {
 
  206              operand, producerResultIndexMap, consumerIndexMap);
 
  207      addToCoveredDims(newIndexingMap);
 
  209    if (!coveredDims.all())
 
 
  221    unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
 
  223  auto consumer = cast<GenericOp>(fusedOperand->
getOwner());
 
  225  Block &producerBlock = producer->getRegion(0).
front();
 
  226  Block &consumerBlock = consumer->getRegion(0).
front();
 
  233  if (producer.hasIndexSemantics()) {
 
  235    unsigned numFusedOpLoops = fusedOp.getNumLoops();
 
  237    fusedIndices.reserve(numFusedOpLoops);
 
  238    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
 
  239                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
 
  240                      return IndexOp::create(rewriter, producer.getLoc(), dim);
 
  242    for (IndexOp indexOp :
 
  243         llvm::make_early_inc_range(producerBlock.
getOps<IndexOp>())) {
 
  244      Value newIndex = affine::AffineApplyOp::create(
 
  245          rewriter, producer.getLoc(),
 
  246          consumerToProducerLoopsMap.
getSubMap(indexOp.getDim()), fusedIndices);
 
  247      mapper.
map(indexOp.getResult(), newIndex);
 
  251  assert(consumer.isDpsInput(fusedOperand) &&
 
  252         "expected producer of input operand");
 
  256    mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
 
  263       producerBlock.
getArguments().take_front(producer.getNumDpsInputs()))
 
  264    mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
 
  269           .take_front(consumer.getNumDpsInputs())
 
  271    mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
 
  274  for (
const auto &bbArg : llvm::enumerate(
 
  275           producerBlock.
getArguments().take_back(producer.getNumDpsInits()))) {
 
  276    if (!preservedProducerResults.count(bbArg.index()))
 
  278    mapper.
map(bbArg.value(), fusedBlock->
addArgument(bbArg.value().getType(),
 
  279                                                      bbArg.value().getLoc()));
 
  284       consumerBlock.
getArguments().take_back(consumer.getNumDpsInits()))
 
  285    mapper.
map(bbArg, fusedBlock->
addArgument(bbArg.getType(), bbArg.getLoc()));
 
  290    if (!isa<IndexOp>(op))
 
  291      rewriter.
clone(op, mapper);
 
  295  auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.
getTerminator());
 
  296  unsigned producerResultNumber =
 
  297      cast<OpResult>(fusedOperand->
get()).getResultNumber();
 
  299      mapper.
lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));
 
  303  if (
replacement == producerYieldOp.getOperand(producerResultNumber)) {
 
  304    if (
auto bb = dyn_cast<BlockArgument>(
replacement))
 
  305      assert(bb.getOwner() != &producerBlock &&
 
  306             "yielded block argument must have been mapped");
 
  308      assert(!producer->isAncestor(
replacement.getDefiningOp()) &&
 
  309             "yielded value must have been mapped");
 
  315    rewriter.
clone(op, mapper);
 
  319  auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.
getTerminator());
 
  321  fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
 
  322                           consumerYieldOp.getNumOperands());
 
  323  for (
const auto &producerYieldVal :
 
  324       llvm::enumerate(producerYieldOp.getOperands())) {
 
  325    if (preservedProducerResults.count(producerYieldVal.index()))
 
  326      fusedYieldValues.push_back(
 
  329  for (
auto consumerYieldVal : consumerYieldOp.getOperands())
 
  331  YieldOp::create(rewriter, fusedOp.getLoc(), fusedYieldValues);
 
  335         "Ill-formed GenericOp region");
 
 
  338FailureOr<mlir::linalg::ElementwiseOpFusionResult>
 
  342         "expected elementwise operation pre-conditions to pass");
 
  343  auto producerResult = cast<OpResult>(fusedOperand->
get());
 
  344  auto producer = cast<GenericOp>(producerResult.getOwner());
 
  345  auto consumer = cast<GenericOp>(fusedOperand->
getOwner());
 
  347  assert(consumer.isDpsInput(fusedOperand) &&
 
  348         "expected producer of input operand");
 
  351  llvm::SmallDenseSet<int> preservedProducerResults =
 
  359  fusedInputOperands.reserve(producer.getNumDpsInputs() +
 
  360                             consumer.getNumDpsInputs());
 
  361  fusedOutputOperands.reserve(preservedProducerResults.size() +
 
  362                              consumer.getNumDpsInits());
 
  363  fusedResultTypes.reserve(preservedProducerResults.size() +
 
  364                           consumer.getNumDpsInits());
 
  365  fusedIndexMaps.reserve(producer->getNumOperands() +
 
  366                         consumer->getNumOperands());
 
  369  auto consumerInputs = consumer.getDpsInputOperands();
 
  370  auto *it = llvm::find_if(consumerInputs, [&](
OpOperand *operand) {
 
  371    return operand == fusedOperand;
 
  373  assert(it != consumerInputs.end() && 
"expected to find the consumer operand");
 
  374  for (
OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
 
  375    fusedInputOperands.push_back(opOperand->get());
 
  376    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
 
  380      producer.getIndexingMapMatchingResult(producerResult);
 
  381  for (
OpOperand *opOperand : producer.getDpsInputOperands()) {
 
  382    fusedInputOperands.push_back(opOperand->get());
 
  385        opOperand, producerResultIndexMap,
 
  386        consumer.getMatchingIndexingMap(fusedOperand));
 
  387    fusedIndexMaps.push_back(map);
 
  392       llvm::make_range(std::next(it), consumerInputs.end())) {
 
  393    fusedInputOperands.push_back(opOperand->get());
 
  394    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
 
  398  for (
const auto &opOperand : llvm::enumerate(producer.getDpsInitsMutable())) {
 
  399    if (!preservedProducerResults.count(opOperand.index()))
 
  402    fusedOutputOperands.push_back(opOperand.value().get());
 
  404        &opOperand.value(), producerResultIndexMap,
 
  405        consumer.getMatchingIndexingMap(fusedOperand));
 
  406    fusedIndexMaps.push_back(map);
 
  407    fusedResultTypes.push_back(opOperand.value().get().getType());
 
  411  for (
OpOperand &opOperand : consumer.getDpsInitsMutable()) {
 
  412    fusedOutputOperands.push_back(opOperand.get());
 
  413    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
 
  414    Type resultType = opOperand.get().getType();
 
  415    if (!isa<MemRefType>(resultType))
 
  416      fusedResultTypes.push_back(resultType);
 
  420  auto fusedOp = GenericOp::create(
 
  421      rewriter, consumer.getLoc(), fusedResultTypes, fusedInputOperands,
 
  423      consumer.getIteratorTypes(),
 
  426  if (!fusedOp.getShapesToLoopsMap()) {
 
  432        fusedOp, 
"fused op failed loop bound computation check");
 
  438      consumer.getMatchingIndexingMap(fusedOperand);
 
  442  assert(invProducerResultIndexMap &&
 
  443         "expected producer result indexig map to be invertible");
 
  446      invProducerResultIndexMap.
compose(consumerResultIndexMap);
 
  449      rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
 
  450      consumer.getNumLoops(), preservedProducerResults);
 
  454  for (
auto [
index, producerResult] : llvm::enumerate(producer->getResults()))
 
  455    if (preservedProducerResults.count(
index))
 
  456      result.replacements[producerResult] = fusedOp->getResult(resultNum++);
 
  457  for (
auto consumerResult : consumer->getResults())
 
  458    result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
 
 
  469        controlFn(std::move(fun)) {}
 
  471  LogicalResult matchAndRewrite(GenericOp genericOp,
 
  474    for (
OpOperand &opOperand : genericOp->getOpOperands()) {
 
  477      if (!controlFn(&opOperand))
 
  480      Operation *producer = opOperand.get().getDefiningOp();
 
  483      FailureOr<ElementwiseOpFusionResult> fusionResult =
 
  485      if (failed(fusionResult))
 
  489      for (
auto [origVal, 
replacement] : fusionResult->replacements) {
 
  571      linalgOp.getIteratorTypesArray();
 
  572  AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
 
 
  573  return linalgOp.hasPureTensorSemantics() &&
 
  574         llvm::all_of(linalgOp.getIndexingMaps().getValue(),
 
  576                        return cast<AffineMapAttr>(attr)
 
  578                            .isProjectedPermutation();
 
 
  592  LogicalResult compute(LinalgOp linalgOp, 
OpOperand *fusableOpOperand,
 
  // Number of iteration-space dimensions (loops) of the original,
  // pre-expansion op. `reassociation` is sized to one entry per original
  // dimension in compute() (it is reserved to fusedIndexMap.getNumDims()),
  // so its size is that loop count.
  596  unsigned getOrigOpNumDims()
 const { 
 return reassociation.size(); }
 
  // Number of iteration-space dimensions of the expanded op. Cached in
  // `expandedOpNumDims` by compute(), where it is the running sum of the
  // per-original-dimension expansion counts.
  597  unsigned getExpandedOpNumDims()
 const { 
 return expandedOpNumDims; }
 
 
  599    return reassociation[i];
 
  602    return expandedShapeMap[i];
 
  // Per-loop extents of the original op's iteration space, as populated in
  // compute() from linalgOp.createLoopRanges(). NOTE: returns a non-owning
  // ArrayRef into `originalLoopExtent`; the view is only valid while this
  // ExpansionInfo instance is alive.
  604  ArrayRef<OpFoldResult> getOriginalShape()
 const { 
 return originalLoopExtent; }
 
 
  615  unsigned expandedOpNumDims;
 
 
  619LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
 
  624  if (reassociationMaps.empty())
 
  626  AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
 
  628  OpBuilder::InsertionGuard g(rewriter);
 
  630  originalLoopExtent = llvm::map_to_vector(
 
  631      linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
 
  632      [](
Range r) { return r.size; });
 
  634  reassociation.clear();
 
  635  expandedShapeMap.clear();
 
  638  SmallVector<unsigned> numExpandedDims(fusedIndexMap.
getNumDims(), 1);
 
  639  expandedShapeMap.resize(fusedIndexMap.
getNumDims());
 
  640  for (
const auto &resultExpr : llvm::enumerate(fusedIndexMap.
getResults())) {
 
  641    unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
 
  642    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
 
  644    ArrayRef<OpFoldResult> shape =
 
  645        expandedShape.slice(foldedDims.
getDimPosition(0), numExpandedDims[pos]);
 
  646    expandedShapeMap[pos].assign(shape.begin(), shape.end());
 
  649  for (
unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.
getNumDims()))
 
  650    if (expandedShapeMap[i].empty())
 
  651      expandedShapeMap[i] = {originalLoopExtent[i]};
 
  655  reassociation.reserve(fusedIndexMap.
getNumDims());
 
  656  for (
const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
 
  657    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
 
  658    reassociation.emplace_back(seq.begin(), seq.end());
 
  659    sum += numFoldedDim.value();
 
  661  expandedOpNumDims = sum;
 
  669                           const ExpansionInfo &expansionInfo) {
 
  672    unsigned pos = cast<AffineDimExpr>(expr).getPosition();
 
  674        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](
int64_t v) {
 
  675          return builder.getAffineDimExpr(static_cast<unsigned>(v));
 
  677    newExprs.append(expandedExprs.begin(), expandedExprs.end());
 
 
  686static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
 
  688                        const ExpansionInfo &expansionInfo) {
 
  691    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
 
  693        expansionInfo.getExpandedShapeOfDim(dim);
 
  694    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
 
  697  std::tie(expandedStaticShape, std::ignore) =
 
  699  return {expandedShape, RankedTensorType::get(expandedStaticShape,
 
  700                                               originalType.getElementType())};
 
 
  709static SmallVector<ReassociationIndices>
 
  711                             const ExpansionInfo &expansionInfo) {
 
  713  unsigned numReshapeDims = 0;
 
  715    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
 
  716    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
 
  718        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
 
  719    reassociation.emplace_back(std::move(
indices));
 
  720    numReshapeDims += numExpandedDims;
 
  722  return reassociation;
 
 
  732                                          const ExpansionInfo &expansionInfo) {
 
  734  for (IndexOp indexOp :
 
  735       llvm::make_early_inc_range(fusedRegion.
front().
getOps<IndexOp>())) {
 
  737        expansionInfo.getExpandedDims(indexOp.getDim());
 
  738    assert(!expandedDims.empty() && 
"expected valid expansion info");
 
  741    if (expandedDims.size() == 1 &&
 
  742        expandedDims.front() == (
int64_t)indexOp.getDim())
 
  749        expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
 
  751    expandedIndices.reserve(expandedDims.size() - 1);
 
  753        expandedDims.drop_front(), std::back_inserter(expandedIndices),
 
  754        [&](
int64_t dim) { return IndexOp::create(rewriter, loc, dim); });
 
  756        IndexOp::create(rewriter, loc, expandedDims.front()).getResult();
 
  757    for (
auto [expandedShape, expandedIndex] :
 
  758         llvm::zip(expandedDimsShape, expandedIndices)) {
 
  763          rewriter, indexOp.getLoc(), idx + 
acc * 
shape,
 
  768    rewriter.
replaceOp(indexOp, newIndexVal);
 
 
  790                                            TransposeOp transposeOp,
 
  792                                            ExpansionInfo &expansionInfo) {
 
  795    auto reassoc = expansionInfo.getExpandedDims(perm);
 
  797      newPerm.push_back(dim);
 
  800  return TransposeOp::create(rewriter, transposeOp.getLoc(), expandedInput,
 
 
  811      expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
 
  813  for (
auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
 
  814    for (
auto j : expansionInfo.getExpandedDims(i))
 
  815      iteratorTypes[
j] = type;
 
  817  Operation *fused = GenericOp::create(rewriter, linalgOp.getLoc(), resultTypes,
 
  818                                       expandedOpOperands, outputs,
 
  819                                       expandedOpIndexingMaps, iteratorTypes);
 
  822  Region &originalRegion = linalgOp->getRegion(0);
 
 
  840                                   ExpansionInfo &expansionInfo) {
 
  843      .Case<TransposeOp>([&](TransposeOp transposeOp) {
 
  845                                         expandedOpOperands[0], outputs[0],
 
  848      .Case<FillOp, CopyOp>([&](
Operation *op) {
 
  849        return clone(rewriter, linalgOp, resultTypes,
 
  850                     llvm::to_vector(llvm::concat<Value>(
 
  851                         llvm::to_vector(expandedOpOperands),
 
  852                         llvm::to_vector(outputs))));
 
  856                                       expandedOpOperands, outputs,
 
  857                                       expansionInfo, expandedOpIndexingMaps);
 
 
  864static std::optional<SmallVector<Value>>
 
  869         "preconditions for fuse operation failed");
 
  875  if (
auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
 
  879            rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
 
  882    expandedShape = expandingReshapeOp.getMixedOutputShape();
 
  883    reassociationIndices = expandingReshapeOp.getReassociationMaps();
 
  884    src = expandingReshapeOp.getSrc();
 
  886    auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
 
  887    if (!collapsingReshapeOp)
 
  891        rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
 
  892    reassociationIndices = collapsingReshapeOp.getReassociationMaps();
 
  893    src = collapsingReshapeOp.getSrc();
 
  896  ExpansionInfo expansionInfo;
 
  897  if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
 
  898                                   reassociationIndices, expandedShape,
 
  903      llvm::map_range(linalgOp.getIndexingMapsArray(), [&](
AffineMap m) {
 
  904        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
 
  912  expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
 
  913  for (
OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
 
  914    if (opOperand == fusableOpOperand) {
 
  915      expandedOpOperands.push_back(src);
 
  918    if (
auto opOperandType =
 
  919            dyn_cast<RankedTensorType>(opOperand->get().getType())) {
 
  920      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
 
  922      RankedTensorType expandedOperandType;
 
  923      std::tie(expandedOperandShape, expandedOperandType) =
 
  925      if (expandedOperandType != opOperand->get().getType()) {
 
  929        if (failed(reshapeLikeShapesAreCompatible(
 
  930                [&](
const Twine &msg) {
 
  933                opOperandType.getShape(), expandedOperandType.getShape(),
 
  937        expandedOpOperands.push_back(tensor::ExpandShapeOp::create(
 
  938            rewriter, loc, expandedOperandType, opOperand->get(), reassociation,
 
  939            expandedOperandShape));
 
  943    expandedOpOperands.push_back(opOperand->get());
 
  947  for (
OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
 
  948    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
 
  949    auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
 
  951    RankedTensorType expandedOutputType;
 
  952    std::tie(expandedOutputShape, expandedOutputType) =
 
  954    if (expandedOutputType != opOperand.get().getType()) {
 
  957      if (failed(reshapeLikeShapesAreCompatible(
 
  958              [&](
const Twine &msg) {
 
  961              opOperandType.getShape(), expandedOutputType.getShape(),
 
  965      outputs.push_back(tensor::ExpandShapeOp::create(
 
  966          rewriter, loc, expandedOutputType, opOperand.get(), reassociation,
 
  967          expandedOutputShape));
 
  969      outputs.push_back(opOperand.get());
 
  976                       outputs, expandedOpIndexingMaps, expansionInfo);
 
  980  for (
OpResult opResult : linalgOp->getOpResults()) {
 
  981    int64_t resultNumber = opResult.getResultNumber();
 
  982    if (resultTypes[resultNumber] != opResult.getType()) {
 
  985              linalgOp.getMatchingIndexingMap(
 
  986                  linalgOp.getDpsInitOperand(resultNumber)),
 
  988      resultVals.push_back(tensor::CollapseShapeOp::create(
 
  989          rewriter, linalgOp.getLoc(), opResult.getType(),
 
  990          fusedOp->
getResult(resultNumber), reassociation));
 
  992      resultVals.push_back(fusedOp->
getResult(resultNumber));
 
 
 1004class FoldWithProducerReshapeOpByExpansion
 
 1005    : 
public OpInterfaceRewritePattern<LinalgOp> {
 
 1007  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
 
 1009                                       PatternBenefit benefit = 1)
 
 1010      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
 
 1011        controlFoldingReshapes(std::move(foldReshapes)) {}
 
 1013  LogicalResult matchAndRewrite(LinalgOp linalgOp,
 
 1014                                PatternRewriter &rewriter)
 const override {
 
 1015    for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
 
 1016      tensor::CollapseShapeOp reshapeOp =
 
 1017          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
 
 1024          (!controlFoldingReshapes(opOperand)))
 
 1027      std::optional<SmallVector<Value>> replacementValues =
 
 1029      if (!replacementValues)
 
 1031      rewriter.
replaceOp(linalgOp, *replacementValues);
 
 1041class FoldPadWithProducerReshapeOpByExpansion
 
 1042    : 
public OpRewritePattern<tensor::PadOp> {
 
 1044  FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
 
 1046                                          PatternBenefit benefit = 1)
 
 1047      : OpRewritePattern<tensor::PadOp>(context, benefit),
 
 1048        controlFoldingReshapes(std::move(foldReshapes)) {}
 
 1050  LogicalResult matchAndRewrite(tensor::PadOp padOp,
 
 1051                                PatternRewriter &rewriter)
 const override {
 
 1052    tensor::CollapseShapeOp reshapeOp =
 
 1053        padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
 
 1056    if (!reshapeOp->hasOneUse())
 
 1059    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
 
 1061                                         "fusion blocked by control function");
 
 1064    ArrayRef<int64_t> low = padOp.getStaticLow();
 
 1065    ArrayRef<int64_t> high = padOp.getStaticHigh();
 
 1066    SmallVector<ReassociationIndices> reassociations =
 
 1067        reshapeOp.getReassociationIndices();
 
 1069    for (
auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
 
 1070      if (reInd.size() != 1 && (l != 0 || h != 0))
 
 1074    SmallVector<OpFoldResult> newLow, newHigh;
 
 1075    RankedTensorType expandedType = reshapeOp.getSrcType();
 
 1076    RankedTensorType paddedType = padOp.getResultType();
 
 1077    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
 
 1078    for (
auto [idx, reInd] : llvm::enumerate(reassociations)) {
 
 1079      if (reInd.size() == 1) {
 
 1080        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
 
 1082      for (
size_t i = 0; i < reInd.size(); ++i) {
 
 1083        newLow.push_back(padOp.getMixedLowPad()[idx]);
 
 1084        newHigh.push_back(padOp.getMixedHighPad()[idx]);
 
 1088    Location loc = padOp->getLoc();
 
 1089    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
 
 1090    auto newPadOp = tensor::PadOp::create(
 
 1091        rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
 
 1092        padOp.getConstantPaddingValue(), padOp.getNofold());
 
 1095        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
 
 1106struct FoldReshapeWithGenericOpByExpansion
 
 1107    : 
public OpRewritePattern<tensor::ExpandShapeOp> {
 
 1109  FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
 
 1111                                      PatternBenefit benefit = 1)
 
 1112      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
 
 1113        controlFoldingReshapes(std::move(foldReshapes)) {}
 
 1115  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
 
 1116                                PatternRewriter &rewriter)
 const override {
 
 1118    auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
 
 1119    if (!producerResult) {
 
 1121                                         "source not produced by an operation");
 
 1124    auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
 
 1127                                         "producer not a generic op");
 
 1132            producer.getDpsInitOperand(producerResult.getResultNumber()))) {
 
 1134          reshapeOp, 
"failed preconditions of fusion with producer generic op");
 
 1137    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
 
 1139                                         "fusion blocked by control function");
 
 1142    std::optional<SmallVector<Value>> replacementValues =
 
 1144            producer, reshapeOp,
 
 1145            producer.getDpsInitOperand(producerResult.getResultNumber()),
 
 1147    if (!replacementValues) {
 
 1149                                         "fusion by expansion failed");
 
 1156    Value reshapeReplacement =
 
 1157        (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
 
 1158                                 .getResultNumber()];
 
 1159    if (
auto collapseOp =
 
 1160            reshapeReplacement.
getDefiningOp<tensor::CollapseShapeOp>()) {
 
 1161      reshapeReplacement = collapseOp.getSrc();
 
 1163    rewriter.
replaceOp(reshapeOp, reshapeReplacement);
 
 1164    rewriter.
replaceOp(producer, *replacementValues);
 
 1186         "expected projected permutation");
 
 1189      llvm::map_range(rangeReassociation, [&](
int64_t pos) -> 
int64_t {
 
 1190        return cast<AffineDimExpr>(indexingMap.
getResults()[pos]).getPosition();
 
 1194  return domainReassociation;
 
 
 1202  assert(!dimSequence.empty() &&
 
 1203         "expected non-empty list for dimension sequence");
 
 1205         "expected indexing map to be projected permutation");
 
 1207  llvm::SmallDenseSet<unsigned, 4> sequenceElements;
 
 1208  sequenceElements.insert_range(dimSequence);
 
 1210  unsigned dimSequenceStart = dimSequence[0];
 
 1211  for (
const auto &expr : enumerate(indexingMap.
getResults())) {
 
 1212    unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
 
 1214    if (dimInMapStart == dimSequenceStart) {
 
 1215      if (expr.index() + dimSequence.size() > indexingMap.
getNumResults())
 
 1218      for (
const auto &dimInSequence : enumerate(dimSequence)) {
 
 1220            cast<AffineDimExpr>(
 
 1221                indexingMap.
getResult(expr.index() + dimInSequence.index()))
 
 1223        if (dimInMap != dimInSequence.value())
 
 1234    if (sequenceElements.count(dimInMapStart))
 
 
 1243  return llvm::all_of(maps, [&](
AffineMap map) {
 
 
 1300  if (!genericOp.hasPureTensorSemantics())
 
 1303  if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](
AffineMap map) {
 
 1304        return map.isProjectedPermutation();
 
 1311  genericOp.getReductionDims(reductionDims);
 
 1313  llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
 
 1314  AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
 
 1315  auto iteratorTypes = genericOp.getIteratorTypesArray();
 
 1318    assert(!foldedRangeDims.empty() && 
"unexpected empty reassociation");
 
 1321    if (foldedRangeDims.size() == 1)
 
 1329    if (llvm::any_of(foldedIterationSpaceDims, [&](
int64_t dim) {
 
 1330          return processedIterationDims.count(dim);
 
 1335    utils::IteratorType startIteratorType =
 
 1336        iteratorTypes[foldedIterationSpaceDims[0]];
 
 1340    if (llvm::any_of(foldedIterationSpaceDims, [&](
int64_t dim) {
 
 1341          return iteratorTypes[dim] != startIteratorType;
 
 1350      bool isContiguous = 
false;
 
 1351      for (
const auto &startDim : llvm::enumerate(reductionDims)) {
 
 1353        if (startDim.value() != foldedIterationSpaceDims[0])
 
 1357        if (startDim.index() + foldedIterationSpaceDims.size() >
 
 1358            reductionDims.size())
 
 1361        isContiguous = 
true;
 
 1362        for (
const auto &foldedDim :
 
 1363             llvm::enumerate(foldedIterationSpaceDims)) {
 
 1364          if (reductionDims[foldedDim.index() + startDim.index()] !=
 
 1365              foldedDim.value()) {
 
 1366            isContiguous = 
false;
 
 1377    if (llvm::any_of(genericOp.getIndexingMapsArray(),
 
 1379                       return !isDimSequencePreserved(indexingMap,
 
 1380                                                      foldedIterationSpaceDims);
 
 1384    processedIterationDims.insert_range(foldedIterationSpaceDims);
 
 1385    iterationSpaceReassociation.emplace_back(
 
 1386        std::move(foldedIterationSpaceDims));
 
 1389  return iterationSpaceReassociation;
 
 
 1394class CollapsingInfo {
 
 1396  LogicalResult 
initialize(
unsigned origNumLoops,
 
 1397                           ArrayRef<ReassociationIndices> foldedIterationDims) {
 
 1398    llvm::SmallDenseSet<int64_t, 4> processedDims;
 
 1401      if (foldedIterationDim.empty())
 
 1405      for (
auto dim : foldedIterationDim) {
 
 1406        if (dim >= origNumLoops)
 
 1408        if (processedDims.count(dim))
 
 1410        processedDims.insert(dim);
 
 1412      collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
 
 1413                                                   foldedIterationDim.end());
 
 1415    if (processedDims.size() > origNumLoops)
 
 1420    for (
auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
 
 1421      if (processedDims.count(dim))
 
 1426    llvm::sort(collapsedOpToOrigOpIterationDim,
 
 1430    origOpToCollapsedOpIterationDim.resize(origNumLoops);
 
 1431    for (
const auto &foldedDims :
 
 1432         llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
 
 1433      for (
const auto &dim : 
enumerate(foldedDims.value()))
 
 1434        origOpToCollapsedOpIterationDim[dim.value()] =
 
 1435            std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
 
 1442    return collapsedOpToOrigOpIterationDim;
 
 1465  ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping()
 const {
 
 1466    return origOpToCollapsedOpIterationDim;
 
 1470  unsigned getCollapsedOpIterationRank()
 const {
 
 1471    return collapsedOpToOrigOpIterationDim.size();
 
 1477  SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
 
 1481  SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
 
 1487static SmallVector<utils::IteratorType>
 
 1488getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
 
 1489                            const CollapsingInfo &collapsingInfo) {
 
 1490  SmallVector<utils::IteratorType> collapsedIteratorTypes;
 
 1492       collapsingInfo.getCollapsedOpToOrigOpMapping()) {
 
 1493    assert(!foldedIterDims.empty() &&
 
 1494           "reassociation indices expected to have non-empty sets");
 
 1498    collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
 
 1500  return collapsedIteratorTypes;
 
 1506getCollapsedOpIndexingMap(AffineMap indexingMap,
 
 1507                          const CollapsingInfo &collapsingInfo) {
 
 1508  MLIRContext *context = indexingMap.
getContext();
 
 1510         "expected indexing map to be projected permutation");
 
 1511  SmallVector<AffineExpr> resultExprs;
 
 1512  auto origOpToCollapsedOpMapping =
 
 1513      collapsingInfo.getOrigOpToCollapsedOpMapping();
 
 1515    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
 
 1517    if (origOpToCollapsedOpMapping[dim].second != 0)
 
 1521    resultExprs.push_back(
 
 1524  return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
 
 1525                        resultExprs, context);
 
 1530static SmallVector<ReassociationIndices>
 
 1531getOperandReassociation(AffineMap indexingMap,
 
 1532                        const CollapsingInfo &collapsingInfo) {
 
 1533  unsigned counter = 0;
 
 1534  SmallVector<ReassociationIndices> operandReassociation;
 
 1535  auto origOpToCollapsedOpMapping =
 
 1536      collapsingInfo.getOrigOpToCollapsedOpMapping();
 
 1537  auto collapsedOpToOrigOpMapping =
 
 1538      collapsingInfo.getCollapsedOpToOrigOpMapping();
 
 1541        cast<AffineDimExpr>(indexingMap.
getResult(counter)).getPosition();
 
 1545    unsigned numFoldedDims =
 
 1546        collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
 
 1548    if (origOpToCollapsedOpMapping[dim].second == 0) {
 
 1549      auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
 
 1550      operandReassociation.emplace_back(range.begin(), range.end());
 
 1552    counter += numFoldedDims;
 
 1554  return operandReassociation;
 
 1558static Value getCollapsedOpOperand(Location loc, LinalgOp op,
 
 1559                                   OpOperand *opOperand,
 
 1560                                   const CollapsingInfo &collapsingInfo,
 
 1561                                   OpBuilder &builder) {
 
 1562  AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
 
 1563  SmallVector<ReassociationIndices> operandReassociation =
 
 1564      getOperandReassociation(indexingMap, collapsingInfo);
 
 1569  Value operand = opOperand->
get();
 
 1570  if (operandReassociation.size() == indexingMap.
getNumResults())
 
 1574  if (isa<MemRefType>(operand.
getType())) {
 
 1575    return memref::CollapseShapeOp::create(builder, loc, operand,
 
 1576                                           operandReassociation)
 
 1579  return tensor::CollapseShapeOp::create(builder, loc, operand,
 
 1580                                         operandReassociation)
 
 1586static void generateCollapsedIndexingRegion(
 
 1587    Location loc, 
Block *block, 
const CollapsingInfo &collapsingInfo,
 
 1588    ArrayRef<OpFoldResult> loopRange, RewriterBase &rewriter) {
 
 1589  OpBuilder::InsertionGuard g(rewriter);
 
 1593  auto indexOps = llvm::to_vector(block->
getOps<linalg::IndexOp>());
 
 1602  llvm::DenseMap<unsigned, Value> indexReplacementVals;
 
 1603  for (
auto foldedDims :
 
 1604       enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
 
 1607        linalg::IndexOp::create(rewriter, loc, foldedDims.index());
 
 1608    for (
auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
 
 1611      indexReplacementVals[dim] =
 
 1612          rewriter.
createOrFold<arith::RemSIOp>(loc, newIndexVal, loopDim);
 
 1614          rewriter.
createOrFold<arith::DivSIOp>(loc, newIndexVal, loopDim);
 
 1616    indexReplacementVals[foldedDims.value().front()] = newIndexVal;
 
 1619  for (
auto indexOp : indexOps) {
 
 1620    auto dim = indexOp.getDim();
 
 1621    rewriter.
replaceOp(indexOp, indexReplacementVals[dim]);
 
 1625static void collapseOperandsAndResults(LinalgOp op,
 
 1626                                       const CollapsingInfo &collapsingInfo,
 
 1627                                       RewriterBase &rewriter,
 
 1628                                       SmallVectorImpl<Value> &inputOperands,
 
 1629                                       SmallVectorImpl<Value> &outputOperands,
 
 1630                                       SmallVectorImpl<Type> &resultTypes) {
 
 1631  Location loc = op->getLoc();
 
 1633      llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
 
 1634        return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
 
 1639  resultTypes.reserve(op.getNumDpsInits());
 
 1640  outputOperands.reserve(op.getNumDpsInits());
 
 1641  for (OpOperand &output : op.getDpsInitsMutable()) {
 
 1643        getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
 
 1644    outputOperands.push_back(newOutput);
 
 1647    if (!op.hasPureBufferSemantics())
 
 1648      resultTypes.push_back(newOutput.
getType());
 
 1653template <
typename OpTy>
 
 1654static OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
 
 1655                               const CollapsingInfo &collapsingInfo) {
 
 1662LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
 
 1663                                      const CollapsingInfo &collapsingInfo) {
 
 1664  SmallVector<Value> inputOperands, outputOperands;
 
 1665  SmallVector<Type> resultTypes;
 
 1666  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
 
 1667                             outputOperands, resultTypes);
 
 1670      rewriter, origOp, resultTypes,
 
 1671      llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
 
 1676GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
 
 1678                                        const CollapsingInfo &collapsingInfo) {
 
 1679  SmallVector<Value> inputOperands, outputOperands;
 
 1680  SmallVector<Type> resultTypes;
 
 1681  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
 
 1682                             outputOperands, resultTypes);
 
 1683  SmallVector<AffineMap> indexingMaps(
 
 1684      llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
 
 1685        return getCollapsedOpIndexingMap(map, collapsingInfo);
 
 1688  SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
 
 1689      origOp.getIteratorTypesArray(), collapsingInfo));
 
 1691  GenericOp collapsedOp = linalg::GenericOp::create(
 
 1692      rewriter, origOp.getLoc(), resultTypes, inputOperands, outputOperands,
 
 1693      indexingMaps, iteratorTypes,
 
 1694      [](OpBuilder &builder, Location loc, 
ValueRange args) {});
 
 1695  Block *origOpBlock = &origOp->getRegion(0).front();
 
 1696  Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
 
 1697  rewriter.
mergeBlocks(origOpBlock, collapsedOpBlock,
 
 1702static LinalgOp createCollapsedOp(LinalgOp op,
 
 1703                                  const CollapsingInfo &collapsingInfo,
 
 1704                                  RewriterBase &rewriter) {
 
 1705  if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
 
 1706    return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
 
 1708    return cloneToCollapsedOp(rewriter, op, collapsingInfo);
 
 1714    LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
 
 1715    RewriterBase &rewriter) {
 
 1717  if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
 
 1719        return foldedDims.size() <= 1;
 
 1723  CollapsingInfo collapsingInfo;
 
 1725          collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
 
 1727        op, 
"illegal to collapse specified dimensions");
 
 1730  bool hasPureBufferSemantics = op.hasPureBufferSemantics();
 
 1731  if (hasPureBufferSemantics &&
 
 1732      !llvm::all_of(op->getOpOperands(), [&](OpOperand &opOperand) -> 
bool {
 
 1733        MemRefType memRefToCollapse =
 
 1734            dyn_cast<MemRefType>(opOperand.get().getType());
 
 1735        if (!memRefToCollapse)
 
 1738        AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
 
 1739        SmallVector<ReassociationIndices> operandReassociation =
 
 1740            getOperandReassociation(indexingMap, collapsingInfo);
 
 1741        return memref::CollapseShapeOp::isGuaranteedCollapsible(
 
 1742            memRefToCollapse, operandReassociation);
 
 1745                                       "memref is not guaranteed collapsible");
 
 1748  SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
 
 1749  auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
 
 1750    if (
auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
 
 1751      return cast<IntegerAttr>(attr).getInt() == value;
 
 1754           actual.getSExtValue() == value;
 
 1756  if (!llvm::all_of(loopRanges, [&](Range range) {
 
 1757        return opFoldIsConstantValue(range.
offset, 0) &&
 
 1758               opFoldIsConstantValue(range.
stride, 1);
 
 1761        op, 
"expected all loop ranges to have zero start and unit stride");
 
 1764  LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
 
 1766  Location loc = op->getLoc();
 
 1767  SmallVector<OpFoldResult> loopBound =
 
 1768      llvm::map_to_vector(loopRanges, [](Range range) { 
return range.
size; });
 
 1770  if (collapsedOp.hasIndexSemantics()) {
 
 1772    OpBuilder::InsertionGuard g(rewriter);
 
 1774    generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
 
 1775                                    collapsingInfo, loopBound, rewriter);
 
 1780  SmallVector<Value> results;
 
 1781  for (
const auto &originalResult : llvm::enumerate(op->getResults())) {
 
 1782    Value collapsedOpResult = collapsedOp->getResult(originalResult.index());
 
 1783    auto originalResultType =
 
 1784        cast<ShapedType>(originalResult.value().getType());
 
 1785    auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.
getType());
 
 1786    if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
 
 1787      AffineMap indexingMap =
 
 1788          op.getIndexingMapMatchingResult(originalResult.value());
 
 1789      SmallVector<ReassociationIndices> reassociation =
 
 1790          getOperandReassociation(indexingMap, collapsingInfo);
 
 1793          "Expected indexing map to be a projected permutation for collapsing");
 
 1794      SmallVector<OpFoldResult> resultShape =
 
 1797      if (isa<MemRefType>(collapsedOpResult.
getType())) {
 
 1798        MemRefType expandShapeResultType = MemRefType::get(
 
 1799            originalResultType.getShape(), originalResultType.getElementType());
 
 1800        result = memref::ExpandShapeOp::create(
 
 1801            rewriter, loc, expandShapeResultType, collapsedOpResult,
 
 1802            reassociation, resultShape);
 
 1804        result = tensor::ExpandShapeOp::create(
 
 1805            rewriter, loc, originalResultType, collapsedOpResult, reassociation,
 
 1808      results.push_back(
result);
 
 1810      results.push_back(collapsedOpResult);
 
 1813  return CollapseResult{results, collapsedOp};
 
 1820class FoldWithProducerReshapeOpByCollapsing
 
 1821    : 
public OpRewritePattern<GenericOp> {
 
 1824  FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
 
 1826                                        PatternBenefit benefit = 1)
 
 1827      : OpRewritePattern<GenericOp>(context, benefit),
 
 1828        controlFoldingReshapes(std::move(foldReshapes)) {}
 
 1830  LogicalResult matchAndRewrite(GenericOp genericOp,
 
 1831                                PatternRewriter &rewriter)
 const override {
 
 1832    for (OpOperand &opOperand : genericOp->getOpOperands()) {
 
 1833      tensor::ExpandShapeOp reshapeOp =
 
 1838      SmallVector<ReassociationIndices> collapsableIterationDims =
 
 1840                                           reshapeOp.getReassociationIndices());
 
 1841      if (collapsableIterationDims.empty() ||
 
 1842          !controlFoldingReshapes(&opOperand)) {
 
 1847          genericOp, collapsableIterationDims, rewriter);
 
 1848      if (!collapseResult) {
 
 1850            genericOp, 
"failed to do the fusion by collapsing transformation");
 
 1853      rewriter.
replaceOp(genericOp, collapseResult->results);
 
 1865struct FoldReshapeWithGenericOpByCollapsing
 
 1866    : 
public OpRewritePattern<tensor::CollapseShapeOp> {
 
 1868  FoldReshapeWithGenericOpByCollapsing(MLIRContext *context,
 
 1870                                       PatternBenefit benefit = 1)
 
 1871      : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
 
 1872        controlFoldingReshapes(std::move(foldReshapes)) {}
 
 1874  LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
 
 1875                                PatternRewriter &rewriter)
 const override {
 
 1878    auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
 
 1879    if (!producerResult) {
 
 1881                                         "source not produced by an operation");
 
 1885    auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
 
 1888                                         "producer not a generic op");
 
 1891    SmallVector<ReassociationIndices> collapsableIterationDims =
 
 1894            producer.getDpsInitOperand(producerResult.getResultNumber()),
 
 1895            reshapeOp.getReassociationIndices());
 
 1896    if (collapsableIterationDims.empty()) {
 
 1898          reshapeOp, 
"failed preconditions of fusion with producer generic op");
 
 1901    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
 
 1903                                         "fusion blocked by control function");
 
 1909    std::optional<CollapseResult> collapseResult =
 
 1911    if (!collapseResult) {
 
 1913          producer, 
"failed to do the fusion by collapsing transformation");
 
 1916    rewriter.
replaceOp(producer, collapseResult->results);
 
 1924class FoldPadWithProducerReshapeOpByCollapsing
 
 1925    : 
public OpRewritePattern<tensor::PadOp> {
 
 1927  FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
 
 1929                                           PatternBenefit benefit = 1)
 
 1930      : OpRewritePattern<tensor::PadOp>(context, benefit),
 
 1931        controlFoldingReshapes(std::move(foldReshapes)) {}
 
 1933  LogicalResult matchAndRewrite(tensor::PadOp padOp,
 
 1934                                PatternRewriter &rewriter)
 const override {
 
 1935    tensor::ExpandShapeOp reshapeOp =
 
 1936        padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
 
 1939    if (!reshapeOp->hasOneUse())
 
 1942    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
 
 1944                                         "fusion blocked by control function");
 
 1947    ArrayRef<int64_t> low = padOp.getStaticLow();
 
 1948    ArrayRef<int64_t> high = padOp.getStaticHigh();
 
 1949    SmallVector<ReassociationIndices> reassociations =
 
 1950        reshapeOp.getReassociationIndices();
 
 1952    for (
auto reInd : reassociations) {
 
 1953      if (reInd.size() == 1)
 
 1955      if (llvm::any_of(reInd, [&](int64_t ind) {
 
 1956            return low[ind] != 0 || high[ind] != 0;
 
 1962    SmallVector<OpFoldResult> newLow, newHigh;
 
 1963    RankedTensorType collapsedType = reshapeOp.getSrcType();
 
 1964    RankedTensorType paddedType = padOp.getResultType();
 
 1965    SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
 
 1966    SmallVector<OpFoldResult> expandedPaddedSizes(
 
 1968                       reshapeOp.getOutputShape(), rewriter));
 
 1969    AffineExpr d0, d1, d2;
 
 1972    Location loc = reshapeOp->getLoc();
 
 1973    for (
auto [idx, reInd] : llvm::enumerate(reassociations)) {
 
 1974      OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
 
 1975      OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
 
 1976      if (reInd.size() == 1) {
 
 1977        collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
 
 1979            rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
 
 1980        expandedPaddedSizes[reInd[0]] = paddedSize;
 
 1982      newLow.push_back(l);
 
 1983      newHigh.push_back(h);
 
 1986    RankedTensorType collapsedPaddedType =
 
 1987        paddedType.clone(collapsedPaddedShape);
 
 1988    auto newPadOp = tensor::PadOp::create(
 
 1989        rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
 
 1990        padOp.getConstantPaddingValue(), padOp.getNofold());
 
 1993        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
 
 1994        expandedPaddedSizes);
 
 2004template <
typename LinalgType>
 
 2005class CollapseLinalgDimensions : 
public OpRewritePattern<LinalgType> {
 
 2007  CollapseLinalgDimensions(MLIRContext *context,
 
 2009                           PatternBenefit benefit = 1)
 
 2010      : OpRewritePattern<LinalgType>(context, benefit),
 
 2011        controlCollapseDimension(std::move(collapseDimensions)) {}
 
 2013  LogicalResult matchAndRewrite(LinalgType op,
 
 2014                                PatternRewriter &rewriter)
 const override {
 
 2015    SmallVector<ReassociationIndices> collapsableIterationDims =
 
 2016        controlCollapseDimension(op);
 
 2017    if (collapsableIterationDims.empty())
 
 2022                                  collapsableIterationDims)) {
 
 2024          op, 
"specified dimensions cannot be collapsed");
 
 2027    std::optional<CollapseResult> collapseResult =
 
 2029    if (!collapseResult) {
 
 2032    rewriter.
replaceOp(op, collapseResult->results);
 
 2049class FoldScalarOrSplatConstant : 
public OpRewritePattern<GenericOp> {
 
 2051  FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
 
 2052      : OpRewritePattern<GenericOp>(context, benefit) {}
 
 2054  LogicalResult matchAndRewrite(GenericOp genericOp,
 
 2055                                PatternRewriter &rewriter)
 const override {
 
 2056    if (!genericOp.hasPureTensorSemantics())
 
 2058    for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
 
 2060      TypedAttr constantAttr;
 
 2061      auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> 
bool {
 
 2063          DenseElementsAttr splatAttr;
 
 2066              splatAttr.
getType().getElementType().isIntOrFloat()) {
 
 2072          IntegerAttr intAttr;
 
 2074            constantAttr = intAttr;
 
 2079          FloatAttr floatAttr;
 
 2081            constantAttr = floatAttr;
 
 2088      auto resultValue = dyn_cast<OpResult>(opOperand->
get());
 
 2089      if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
 
 2095      SmallVector<AffineMap> fusedIndexMaps;
 
 2096      SmallVector<Value> fusedOperands;
 
 2097      SmallVector<Location> fusedLocs{genericOp.getLoc()};
 
 2098      fusedIndexMaps.reserve(genericOp->getNumOperands());
 
 2099      fusedOperands.reserve(genericOp.getNumDpsInputs());
 
 2100      fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
 
 2101      for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
 
 2102        if (inputOperand == opOperand)
 
 2104        Value inputValue = inputOperand->get();
 
 2105        fusedIndexMaps.push_back(
 
 2106            genericOp.getMatchingIndexingMap(inputOperand));
 
 2107        fusedOperands.push_back(inputValue);
 
 2108        fusedLocs.push_back(inputValue.
getLoc());
 
 2110      for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
 
 2111        fusedIndexMaps.push_back(
 
 2112            genericOp.getMatchingIndexingMap(&outputOperand));
 
 2118            genericOp, 
"fused op loop bound computation failed");
 
 2122      Value scalarConstant =
 
 2123          arith::ConstantOp::create(rewriter, def->
getLoc(), constantAttr);
 
 2125      SmallVector<Value> outputOperands = genericOp.getOutputs();
 
 2127          GenericOp::create(rewriter, rewriter.
getFusedLoc(fusedLocs),
 
 2128                            genericOp->getResultTypes(),
 
 2132                            genericOp.getIteratorTypes(),
 
 2138      Region ®ion = genericOp->getRegion(0);
 
 2143      Region &fusedRegion = fusedOp->getRegion(0);
 
 2146      rewriter.
replaceOp(genericOp, fusedOp->getResults());
 
 2164struct RemoveOutsDependency : 
public OpRewritePattern<GenericOp> {
 
 2165  using OpRewritePattern<GenericOp>::OpRewritePattern;
 
 2167  LogicalResult matchAndRewrite(GenericOp op,
 
 2168                                PatternRewriter &rewriter)
 const override {
 
 2170    bool modifiedOutput = 
false;
 
 2171    Location loc = op.getLoc();
 
 2172    for (OpOperand &opOperand : op.getDpsInitsMutable()) {
 
 2173      if (!op.payloadUsesValueFromOperand(&opOperand)) {
 
 2174        Value operandVal = opOperand.
get();
 
 2175        auto operandType = dyn_cast<RankedTensorType>(operandVal.
getType());
 
 2184        auto definingOp = operandVal.
getDefiningOp<tensor::EmptyOp>();
 
 2187        modifiedOutput = 
true;
 
 2188        SmallVector<OpFoldResult> mixedSizes =
 
 2190        Value emptyTensor = tensor::EmptyOp::create(
 
 2191            rewriter, loc, mixedSizes, operandType.getElementType());
 
 2195    if (!modifiedOutput) {
 
 2205struct FoldFillWithGenericOp : 
public OpRewritePattern<GenericOp> {
 
 2206  using OpRewritePattern<GenericOp>::OpRewritePattern;
 
 2208  LogicalResult matchAndRewrite(GenericOp genericOp,
 
 2209                                PatternRewriter &rewriter)
 const override {
 
 2210    if (!genericOp.hasPureTensorSemantics())
 
 2212    bool fillFound = 
false;
 
 2213    Block &payload = genericOp.getRegion().front();
 
 2214    for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
 
 2215      if (!genericOp.payloadUsesValueFromOperand(opOperand))
 
 2221      Value fillVal = fillOp.value();
 
 2223          cast<RankedTensorType>(fillOp.result().getType()).getElementType();
 
 2224      Value convertedVal =
 
 2239                                                    controlFoldingReshapes);
 
 2240  patterns.add<FoldPadWithProducerReshapeOpByExpansion>(
patterns.getContext(),
 
 2241                                                        controlFoldingReshapes);
 
 2243                                                     controlFoldingReshapes);
 
 
 2250                                                      controlFoldingReshapes);
 
 2251  patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
 
 2252      patterns.getContext(), controlFoldingReshapes);
 
 2254                                                     controlFoldingReshapes);
 
 
 2260  auto *context = 
patterns.getContext();
 
 2261  patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
 
 2262  patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
 
 2263               RemoveOutsDependency>(context);
 
 
 2271  patterns.add<CollapseLinalgDimensions<linalg::GenericOp>,
 
 2272               CollapseLinalgDimensions<linalg::CopyOp>>(
 
 2273      patterns.getContext(), controlCollapseDimensions);
 
 
 2288struct LinalgElementwiseOpFusionPass
 
 2290          LinalgElementwiseOpFusionPass> {
 
 2292      LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
 
 2293  void runOnOperation()
 override {
 
 2300      Operation *producer = fusedOperand->get().getDefiningOp();
 
 2301      return producer && producer->
hasOneUse();
 
 2310    affine::AffineApplyOp::getCanonicalizationPatterns(
patterns, context);
 
 2311    GenericOp::getCanonicalizationPatterns(
patterns, context);
 
 2312    tensor::ExpandShapeOp::getCanonicalizationPatterns(
patterns, context);
 
 2313    tensor::CollapseShapeOp::getCanonicalizationPatterns(
patterns, context);