// NOTE(review): this listing is an excerpt of a larger file. The leading
// integers are line numbers of the original file, and lines that are not
// shown (guarded statements, closing braces, some declarations) are elided.
//
// Base class parameterized on the derived pattern type: the code below
// downcasts `this` to `const ConcreteType *` to call `matchIndexingMaps` and
// `getRegionComputeFn` on the derived class (CRTP).
39 template <
typename ConcreteType>
// Optional integer / floating-point payloads. Presumably the two fields of
// `APIntOrFloat` (its header line is not shown); the fold loops below
// dereference `apFloat` on the float path and `apInt` on the integer path.
43 std::optional<APInt> apInt;
44 std::optional<APFloat> apFloat;
// Array counterpart handed to the region computation. The code below uses
// its `apInts` and `apFloats` members (their declarations are not shown).
46 struct APIntOrFloatArray {
// Callback mapping the per-input element values to one output element value.
50 using RegionComputationFn =
51 std::function<APIntOrFloat(
const APIntOrFloatArray &)>;
// --- Precondition tests on `genericOp` ---
// The statement guarded by each `if` below is not shown; presumably each one
// rejects the match -- TODO confirm against the full file.
// Tests for missing tensor semantics.
60 if (!genericOp.hasTensorSemantics())
// Tests for anything other than exactly one DPS init operand.
64 if (genericOp.getNumDpsInits() != 1)
// The first result type is expected to be a ShapedType with a static shape.
67 auto outputType = dyn_cast<ShapedType>(genericOp.getResultTypes().front());
70 if (!outputType || !outputType.hasStaticShape())
// Tests that every input value has a ShapedType.
73 if (!llvm::all_of(genericOp.getInputs(), [](
Value input) {
74 return isa<ShapedType>(input.getType());
// Helper returning the element type of a value; uses `cast`, so it requires
// the value to have a ShapedType.
79 auto getOperandElementType = [](
Value value) {
80 return cast<ShapedType>(value.getType()).getElementType();
// Tail of a condition over the element types of all operands; the head of
// the expression is not shown.
83 llvm::map_range(genericOp->getOperands(), getOperandElementType)))
// Tests that the output element type is an integer or float type.
87 auto elementType = outputType.getElementType();
88 if (!elementType.isIntOrFloat())
// Tests that every indexing map is a permutation.
96 if (!llvm::all_of(genericOp.getIndexingMapsArray(),
97 [](
AffineMap map) { return map.isPermutation(); }))
// Tests whether the payload reads the value of any init operand.
100 for (
OpOperand &operand : genericOp.getDpsInitsMutable()) {
101 if (genericOp.payloadUsesValueFromOperand(&operand))
// Pattern-specific indexing-map test, delegated to the derived class.
106 if (!
static_cast<const ConcreteType *
>(
this)->matchIndexingMaps(genericOp))
// Pattern-specific per-element computation, obtained from the derived class.
111 RegionComputationFn computeFn =
112 static_cast<const ConcreteType *
>(
this)->getRegionComputeFn(genericOp);
117 int numInputs = genericOp.getNumDpsInputs();
// Walks the DPS input operands together with their index (body not shown).
119 for (
const auto &en :
llvm::enumerate(genericOp.getDpsInputOperands())) {
// Consults the caller-provided `controlFn` for each input operand.
127 for (
OpOperand *operand : genericOp.getDpsInputOperands()) {
128 if (!controlFn(operand))
132 auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
134 int64_t numElements = outputType.getNumElements();
// Sizes the output buffer that matches the element kind. The float buffer is
// filled with APFloat(0.f); the integer resize is presumably the `else`
// branch (that line is not shown).
141 if (isa<FloatType>(elementType))
142 fpOutputValues.resize(numElements, APFloat(0.f));
144 intOutputValues.resize(numElements);
// Helper collecting the dimension position of each result of an indexing
// map; uses `cast<AffineDimExpr>`, so every result must be a dim expression.
147 auto getDimPositions = [](
AffineMap map) {
151 dims.push_back(cast<AffineDimExpr>(result).getPosition());
// Dim positions of the first `numInputs` indexing maps (the inputs) and of
// the last indexing map (used as the output map).
157 for (
int i = 0; i < numInputs; ++i)
158 inputDims.push_back(getDimPositions(genericOp.getIndexingMapsArray()[i]));
159 auto outputDims = getDimPositions(genericOp.getIndexingMapsArray().back());
160 auto outputShape = outputType.getShape();
// Written by `computeRemappedLinearIndex` and read by the fold loops below.
169 uint64_t dstLinearIndex = 0;
// Scratch argument reused for every call to `computeFn`.
173 APIntOrFloatArray computeFnInputs;
// Shapes of all input values, indexed like the inputs.
175 auto inputShapes = llvm::to_vector<4>(
176 llvm::map_range(genericOp.getInputs(), [](
Value value) {
177 return cast<ShapedType>(value.getType()).getShape();
// Converts a linear iteration index into the linear index of each input
// (`srcLinearIndices`) and of the output (`dstLinearIndex`).
// NOTE(review): `linearIndex` and `totalCount` are `int` while `numElements`
// is `int64_t` -- confirm very large tensors cannot reach this point.
183 auto computeRemappedLinearIndex = [&](
int linearIndex) {
184 int totalCount = linearIndex;
// Step 1: delinearize over `loopBounds`, last dimension varying fastest.
185 for (
int dim = loopBounds.size() - 1; dim >= 0; --dim) {
186 indices[dim] = totalCount % loopBounds[dim];
187 totalCount /= loopBounds[dim];
// Step 2: route the loop indices through each operand's dim positions.
190 for (
int dim = loopBounds.size() - 1; dim >= 0; --dim) {
191 for (
int i = 0; i < numInputs; ++i)
192 srcIndices[i][dim] = indices[inputDims[i][dim]];
193 dstIndices[dim] = indices[outputDims[dim]];
// Step 3: re-linearize against each operand shape, starting from the
// leading index and folding in the remaining dimensions.
196 dstLinearIndex = dstIndices.front();
197 for (
int i = 0; i < numInputs; ++i)
198 srcLinearIndices[i] = srcIndices[i].front();
// The output rank bounds the loop for the inputs as well; this assumes every
// input has the same rank as the output -- TODO confirm.
200 for (
int dim = 1; dim < outputType.getRank(); ++dim) {
201 dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
202 for (
int i = 0; i < numInputs; ++i)
203 srcLinearIndices[i] =
204 srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
208 bool isFloat = isa<FloatType>(elementType);
// Floating-point path (presumably guarded by `isFloat`; the guard line is
// not shown): gather APFloat ranges of the inputs, then evaluate `computeFn`
// once per output element.
211 for (
int i = 0; i < numInputs; ++i)
212 inFpRanges.push_back(inputValues[i].getValues<APFloat>());
214 computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));
219 for (
int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
220 computeRemappedLinearIndex(linearIndex);
// Load the remapped element of every input into the scratch argument.
223 for (
int i = 0; i < numInputs; ++i)
224 computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];
// Dereferences the optional directly: `computeFn` must return a value in
// `apFloat` on this path.
228 fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
// Integer path: same structure as the float path, using APInt values.
232 for (
int i = 0; i < numInputs; ++i)
233 inIntRanges.push_back(inputValues[i].getValues<APInt>());
235 computeFnInputs.apInts.resize(numInputs);
240 for (
int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
241 computeRemappedLinearIndex(linearIndex);
// Load the remapped element of every input into the scratch argument.
244 for (
int i = 0; i < numInputs; ++i)
245 computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];
// Dereferences the optional directly: `computeFn` must return a value in
// `apInt` on this path.
249 intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
// Concrete pattern built on `FoldConstantBase`, passing itself as the
// template argument and inheriting the base constructors.
// NOTE(review): this listing is an excerpt; the leading integers are line
// numbers of the original file, and closing braces and the statements
// guarded by the `if` lines are not shown.
266 struct FoldConstantTranspose :
public FoldConstantBase<FoldConstantTranspose> {
267 using FoldConstantBase::FoldConstantBase;
// Accepts only ops with exactly two indexing maps.
269 bool matchIndexingMaps(GenericOp genericOp)
const {
271 return genericOp.getIndexingMapsArray().size() == 2;
// Inspects the op body and returns the per-element computation. The visible
// tests look at: the body holding a single element, the terminator being a
// `linalg::YieldOp`, and every yielded value being block argument number 0
// of that same body. What is returned when a test fails is not shown.
274 RegionComputationFn getRegionComputeFn(GenericOp genericOp)
const {
277 if (!llvm::hasSingleElement(body))
279 auto yieldOp = dyn_cast<linalg::YieldOp>(body.
getTerminator());
284 for (
Value yieldVal : yieldOp.getValues()) {
285 auto yieldArg = dyn_cast<BlockArgument>(yieldVal);
286 if (!yieldArg || yieldArg.getOwner() != &body)
288 if (yieldArg.getArgNumber() != 0)
// The computation forwards the first input element unchanged: the integer
// value when no float inputs are present, otherwise the float value.
293 return [](
const APIntOrFloatArray &inputs) {
294 if (inputs.apFloats.empty())
295 return APIntOrFloat{inputs.apInts.front(), std::nullopt};
296 return APIntOrFloat{std::nullopt, inputs.apFloats.front()};
// Adds the FoldConstantTranspose pattern to `patterns`, forwarding `context`
// and `controlFn` as its constructor arguments. (The enclosing function
// header is not shown in this excerpt.)
307 patterns.
insert<FoldConstantTranspose>(context, controlFn);
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
An attribute that represents a reference to a dense vector or tensor object.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents an operand of an operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
std::function< bool(OpOperand *fusedOperand)> ControlFusionFn
Function type which is used to control when to stop fusion.
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, const ControlFusionFn &controlFn)
Patterns to constant fold Linalg operations.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...