24 #include "llvm/ADT/APFloat.h"
25 #include "llvm/ADT/FloatingPointMode.h"
26 #include "llvm/ADT/SmallVector.h"
40 template <
class SrcValType,
class TargetValType,
class TargetType>
43 const std::function<TargetValType(
const SrcValType &)> &toApply,
44 TargetType targetType) {
49 for (
auto val : toTransform.
getValues<SrcValType>()) {
50 auto transformedVal = toApply(val);
51 transformedValues.push_back(transformedVal);
55 auto inShape = toTransform.
getType();
56 auto outTy = inShape.cloneWith({}, targetType);
63 const std::function<APFloat(
const APFloat &)> &toApply,
69 if (isa<FloatType>(toCheck.getType().getElementType())) {
73 "Unexpected input tensor type: the "
74 "TOSA spec only allows floats");
88 "Non-const or non-dense input tensor");
93 if (isa<ConstOp>(toCheck.getDefiningOp())) {
98 "The reciprocal can only be folded if "
99 "it operates on a TOSA constant");
106 auto floatCheck = notifyIfNotFloat(toCheck, location, rewriter);
110 return notifyIfNoTosaDenseConstantTensor(toCheck, location, rewriter);
122 assert(unaryOp->getNumOperands() == 1);
123 auto inputOp = unaryOp->getOperand(0);
126 if (isa<SplatElementsAttr>(values)) {
132 return inputOp.hasOneUse();
135 template <
typename RangeType>
137 ShapedType outputType,
139 using ElementType = std::decay_t<decltype(*std::begin(data))>;
141 assert(inputType.getElementType() == outputType.getElementType());
143 if (inputType.getNumElements() == 0)
146 auto inputShape = inputType.getShape();
154 auto initialValue = *std::begin(data);
159 auto srcLinearIndex = it.index();
161 uint64_t dstLinearIndex = 0;
162 for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
164 auto sourceIndexForDim = srcLinearIndex % inputShape[dim];
165 srcLinearIndex /= inputShape[dim];
170 outputStrides[invertedPermValues[dim]] * sourceIndexForDim;
173 outputValues[dstLinearIndex] = it.value();
185 ShapedType outputType,
187 if (
auto data = attr.tryGetValues<
bool>())
188 return transposeType(*data, inputType, outputType, permValues);
190 if (
auto data = attr.tryGetValues<int8_t>())
191 return transposeType(*data, inputType, outputType, permValues);
193 if (
auto data = attr.tryGetValues<int16_t>())
194 return transposeType(*data, inputType, outputType, permValues);
196 if (
auto data = attr.tryGetValues<int32_t>())
197 return transposeType(*data, inputType, outputType, permValues);
199 if (
auto data = attr.tryGetValues<int64_t>())
200 return transposeType(*data, inputType, outputType, permValues);
202 if (
auto data = attr.tryGetValues<
float>())
203 return transposeType(*data, inputType, outputType, permValues);
205 if (
auto data = attr.tryGetValues<APFloat>())
206 return transposeType(*data, inputType, outputType, permValues);
211 struct TosaFoldConstantTranspose :
public OpRewritePattern<tosa::TransposeOp> {
216 auto outputType = cast<ShapedType>(op.getType());
218 if (!outputType.getElementType().isIntOrIndexOrFloat())
221 ElementsAttr inputValues;
225 if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->
getUsers()))
231 auto permValues = llvm::map_to_vector(
233 permAttr.getValues<APInt>(),
234 [](
const APInt &val) { return val.getSExtValue(); });
236 auto inputType = cast<ShapedType>(op.getInput1().getType());
238 auto resultAttr =
transpose(inputValues, inputType, outputType, permValues);
241 op,
"unsupported attribute or element type");
255 auto inputTensor = recip.getInput1();
259 notifyIfNotConstantFloatTosaTensor(inputTensor, recip, rewriter);
260 if (
failed(preCondCheck)) {
269 if (!constantUnaryOpShouldBeFolded(recip, inputValues)) {
271 recip,
"Currently, reciprocals will only be folded if the input "
272 "tensor has a single user");
276 auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
277 inputValues, &ReciprocalOp::calcOneElement,
291 int64_t remaining = index;
293 for (int64_t i = tensorShape.size() - 1; i >= 0; --i) {
294 position[i] = remaining % tensorShape[i];
295 remaining /= tensorShape[i];
306 int64_t multiplierTmp = 1;
307 for (int64_t i = position.size() - 1; i >= 0; --i) {
308 index += position[i] * multiplierTmp;
309 multiplierTmp *= tensorShape[i];
314 template <
typename OperationType>
315 llvm::APInt calculateReducedValue(
const mlir::ElementsAttr &oldTensorAttr,
317 int64_t reductionAxis,
318 int64_t reductionIndex) {
321 newShape[reductionAxis] = 1;
324 getPositionFromIndex(reductionIndex, newShape);
325 auto oldTensor = oldTensorAttr.getValues<llvm::APInt>();
327 position[reductionAxis] = 0;
328 int64_t indexAtOldTensor = getIndexFromPosition(position, oldShape);
329 llvm::APInt reducedValue = oldTensor[indexAtOldTensor];
331 for (int64_t reductionAxisVal = 1; reductionAxisVal < oldShape[reductionAxis];
332 ++reductionAxisVal) {
334 int64_t stride = std::accumulate(oldShape.begin() + reductionAxis + 1,
335 oldShape.end(), 1, std::multiplies<int>());
336 int64_t index = indexAtOldTensor + stride * reductionAxisVal;
338 OperationType::calcOneElement(reducedValue, oldTensor[index]);
343 template <
typename OperationType>
344 struct ReduceConstantOptimization :
public OpRewritePattern<OperationType> {
347 bool aggressiveReduceConstant)
349 aggressiveReduceConstant(aggressiveReduceConstant) {}
355 Value inputOp = op.getInput();
360 op,
"reduce input must be const operation");
362 if (!inputOp.
hasOneUse() && !this->aggressiveReduceConstant)
364 op,
"input operation has more than one user");
366 auto resultType = cast<ShapedType>(op.getOutput().getType());
368 if (!resultType.hasStaticShape())
371 auto reductionAxis = op.getAxis();
372 const auto denseElementsAttr = constOp.getValue();
373 const auto shapedOldElementsValues =
374 cast<ShapedType>(denseElementsAttr.getType());
376 if (!llvm::isa<IntegerType>(shapedOldElementsValues.getElementType()))
378 op,
"reduce input currently supported with integer type");
380 auto oldShape = shapedOldElementsValues.getShape();
381 auto newShape = resultType.getShape();
383 auto newNumOfElements = std::accumulate(newShape.begin(), newShape.end(), 1,
384 std::multiplies<int>());
387 for (int64_t reductionIndex = 0; reductionIndex < newNumOfElements;
391 newReducedTensor[reductionIndex] = calculateReducedValue<OperationType>(
392 denseElementsAttr, oldShape, reductionAxis, reductionIndex);
395 auto rankedTensorType = cast<RankedTensorType>(resultType);
401 const bool aggressiveReduceConstant;
408 bool aggressiveReduceConstant) {
409 patterns.
add<ReduceConstantOptimization<ReduceAllOp>>(
410 ctx, aggressiveReduceConstant);
411 patterns.
add<ReduceConstantOptimization<ReduceAnyOp>>(
412 ctx, aggressiveReduceConstant);
413 patterns.
add<ReduceConstantOptimization<ReduceMaxOp>>(
414 ctx, aggressiveReduceConstant);
415 patterns.
add<ReduceConstantOptimization<ReduceMinOp>>(
416 ctx, aggressiveReduceConstant);
417 patterns.
add<ReduceConstantOptimization<ReduceProdOp>>(
418 ctx, aggressiveReduceConstant);
419 patterns.
add<ReduceConstantOptimization<ReduceSumOp>>(
420 ctx, aggressiveReduceConstant);
425 patterns.
add<TosaFoldConstantTranspose>(ctx);
430 patterns.
add<TosaFoldConstantReciprocal>(ctx);
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
int64_t getNumElements() const
Returns the number of elements held by this attribute.
Type getElementType() const
Return the element type of this DenseElementsAttr.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
An attribute that represents a reference to a dense integer vector or tensor object.
MLIRContext is the top-level object for a collection of MLIR operations.
user_range getUsers()
Returns a range of all users.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched and rewritten.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replaces the original op with the newly created op).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool hasOneUse() const
Returns true if this value has exactly one use.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateTosaConstantReduction(MLIRContext *ctx, RewritePatternSet &patterns, bool aggressiveReduceConstant)
void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx, RewritePatternSet &patterns)
void populateTosaFoldConstantTransposePatterns(MLIRContext *ctx, RewritePatternSet &patterns)
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...