17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallVector.h"
27 populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
29 populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
32 populateAllReduceEndomorphismSimplificationPatterns<arith::MinimumFOp>(
34 populateAllReduceEndomorphismSimplificationPatterns<arith::MinSIOp>(
36 populateAllReduceEndomorphismSimplificationPatterns<arith::MinUIOp>(
39 populateAllReduceEndomorphismSimplificationPatterns<arith::MaximumFOp>(
41 populateAllReduceEndomorphismSimplificationPatterns<arith::MaxSIOp>(
43 populateAllReduceEndomorphismSimplificationPatterns<arith::MaxUIOp>(
58 struct MeshShapeFolder
59 : OpRewritePatternWithSymbolTableCollection<MeshShapeOp> {
62 LogicalResult matchAndRewrite(MeshShapeOp op,
66 op.getOperation(), op.getMeshAttr());
70 ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
71 SmallVector<MeshAxis> opAxesIota;
72 if (opMeshAxes.empty()) {
73 opAxesIota.resize(mesh.getRank());
74 std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
75 opMeshAxes = opAxesIota;
77 if (llvm::all_of(opMeshAxes, [&mesh](
MeshAxis axis) {
78 return ShapedType::isDynamic(mesh.getShape()[axis]);
84 SmallVector<Value> newResults(op->getResults().size());
85 SmallVector<MeshAxis> newShapeOpMeshAxes;
86 SmallVector<size_t> newToOldResultsIndexMap;
88 for (
size_t i = 0; i < opMeshAxes.size(); ++i) {
89 auto meshAxisSize = mesh.getShape()[opMeshAxes[i]];
90 if (ShapedType::isDynamic(meshAxisSize)) {
91 newToOldResultsIndexMap.push_back(i);
92 newShapeOpMeshAxes.push_back(opMeshAxes[i]);
95 newResults[i] = builder.create<arith::ConstantOp>(
96 builder.getIndexAttr(meshAxisSize));
101 if (!newShapeOpMeshAxes.empty()) {
102 MeshShapeOp newShapeOp =
103 builder.create<MeshShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
104 for (
size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
105 newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
118 patterns.add<MeshShapeFolder>(symbolTableCollection,
patterns.getContext());
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements). The result types of the given op and the replacements must match. The original op is erased.
This class represents a collection of SymbolTables.
void populateSimplificationPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePatternWithSymbolTableCollection(SymbolTableCollection &symbolTableCollection, OpRewritePatternArgs &&...opRewritePatternArgs)