16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/SmallVector.h"
25 populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
27 populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
30 populateAllReduceEndomorphismSimplificationPatterns<arith::MinimumFOp>(
32 populateAllReduceEndomorphismSimplificationPatterns<arith::MinSIOp>(
34 populateAllReduceEndomorphismSimplificationPatterns<arith::MinUIOp>(
37 populateAllReduceEndomorphismSimplificationPatterns<arith::MaximumFOp>(
39 populateAllReduceEndomorphismSimplificationPatterns<arith::MaxSIOp>(
41 populateAllReduceEndomorphismSimplificationPatterns<arith::MaxUIOp>(
56 struct GridShapeFolder
57 : OpRewritePatternWithSymbolTableCollection<GridShapeOp> {
60 LogicalResult matchAndRewrite(GridShapeOp op,
64 op.getOperation(), op.getGridAttr());
68 ArrayRef<GridAxis> opGridAxes = op.getAxes();
69 SmallVector<GridAxis> opAxesIota;
70 if (opGridAxes.empty()) {
71 opAxesIota.resize(grid.getRank());
72 std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
73 opGridAxes = opAxesIota;
75 if (llvm::all_of(opGridAxes, [&grid](
GridAxis axis) {
76 return ShapedType::isDynamic(grid.getShape()[axis]);
82 SmallVector<Value> newResults(op->getResults().size());
83 SmallVector<GridAxis> newShapeOpGridAxes;
84 SmallVector<size_t> newToOldResultsIndexMap;
86 for (
size_t i = 0; i < opGridAxes.size(); ++i) {
87 auto gridAxisSize = grid.getShape()[opGridAxes[i]];
88 if (ShapedType::isDynamic(gridAxisSize)) {
89 newToOldResultsIndexMap.push_back(i);
90 newShapeOpGridAxes.push_back(opGridAxes[i]);
93 newResults[i] = arith::ConstantOp::create(
94 builder, builder.getIndexAttr(gridAxisSize));
99 if (!newShapeOpGridAxes.empty()) {
100 GridShapeOp newShapeOp =
101 GridShapeOp::create(builder, grid.getSymName(), newShapeOpGridAxes);
102 for (
size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
103 newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
116 patterns.add<GridShapeFolder>(symbolTableCollection,
patterns.getContext());
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents a collection of SymbolTables.
void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateSimplificationPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePatternWithSymbolTableCollection(SymbolTableCollection &symbolTableCollection, OpRewritePatternArgs &&...opRewritePatternArgs)