23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/SmallVector.h"
38template <
class SrcValType,
class TargetValType,
class TargetType>
41 const std::function<TargetValType(
const SrcValType &)> &toApply,
42 TargetType targetType) {
47 for (
auto val : toTransform.
getValues<SrcValType>()) {
48 auto transformedVal = toApply(val);
49 transformedValues.push_back(transformedVal);
53 auto inShape = toTransform.
getType();
54 auto outTy = inShape.cloneWith({}, targetType);
61 const std::function<APFloat(
const APFloat &)> &toApply,
62 FloatType targetType);
67 if (isa<FloatType>(toCheck.getType().getElementType())) {
71 "Unexpected input tensor type: the "
72 "TOSA spec only allows floats");
86 "Non-const or non-dense input tensor");
91 if (isa<ConstOp>(toCheck.getDefiningOp())) {
96 "The reciprocal can only be folded if "
97 "it operates on a TOSA constant");
104 auto floatCheck = notifyIfNotFloat(toCheck, location, rewriter);
108 return notifyIfNoTosaDenseConstantTensor(toCheck, location, rewriter);
120 assert(unaryOp->getNumOperands() == 1);
121 auto inputOp = unaryOp->getOperand(0);
124 if (isa<SplatElementsAttr>(values)) {
130 return inputOp.hasOneUse();
133template <
typename RangeType>
135 ShapedType outputType,
137 using ElementType = std::decay_t<
decltype(*std::begin(data))>;
139 assert(inputType.getElementType() == outputType.getElementType());
141 if (inputType.getNumElements() == 0)
144 auto inputShape = inputType.getShape();
152 auto initialValue = *std::begin(data);
156 for (
const auto &it : llvm::enumerate(data)) {
157 auto srcLinearIndex = it.index();
159 uint64_t dstLinearIndex = 0;
160 for (
int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
162 auto sourceIndexForDim = srcLinearIndex % inputShape[dim];
163 srcLinearIndex /= inputShape[dim];
168 outputStrides[invertedPermValues[dim]] * sourceIndexForDim;
171 outputValues[dstLinearIndex] = it.value();
180std::optional<ArrayRef<T>> tryGetDenseResourceValues(ElementsAttr attr) {
181 if (
auto denseResource = dyn_cast<DenseResourceElementsAttr>(attr)) {
193 return blob->template getDataAs<T>();
204 ShapedType outputType,
207 if (
auto data = attr.tryGetValues<
bool>())
208 return transposeType(*data, inputType, outputType, permValues);
210 if (
auto data = attr.tryGetValues<int8_t>())
211 return transposeType(*data, inputType, outputType, permValues);
213 if (
auto data = attr.tryGetValues<int16_t>())
214 return transposeType(*data, inputType, outputType, permValues);
216 if (
auto data = attr.tryGetValues<int32_t>())
217 return transposeType(*data, inputType, outputType, permValues);
219 if (
auto data = attr.tryGetValues<
int64_t>())
220 return transposeType(*data, inputType, outputType, permValues);
222 if (
auto data = attr.tryGetValues<
float>())
223 return transposeType(*data, inputType, outputType, permValues);
225 if (
auto data = attr.tryGetValues<APFloat>())
226 return transposeType(*data, inputType, outputType, permValues);
229 if (isa<DenseResourceElementsAttr>(attr)) {
230 auto elementTy = attr.getElementType();
232 if (
auto data = tryGetDenseResourceValues<bool>(attr);
233 data && elementTy.isInteger(1))
234 return transposeType(*data, inputType, outputType, permValues);
236 if (
auto data = tryGetDenseResourceValues<int8_t>(attr);
237 data && elementTy.isInteger(8))
238 return transposeType(*data, inputType, outputType, permValues);
240 if (
auto data = tryGetDenseResourceValues<int16_t>(attr);
241 data && elementTy.isInteger(16))
242 return transposeType(*data, inputType, outputType, permValues);
244 if (
auto data = tryGetDenseResourceValues<int32_t>(attr);
245 data && elementTy.isInteger(32))
246 return transposeType(*data, inputType, outputType, permValues);
248 if (
auto data = tryGetDenseResourceValues<int64_t>(attr);
249 data && elementTy.isInteger(64))
250 return transposeType(*data, inputType, outputType, permValues);
252 if (
auto data = tryGetDenseResourceValues<float>(attr);
253 data && elementTy.isF32())
254 return transposeType(*data, inputType, outputType, permValues);
260struct TosaFoldConstantTranspose :
public OpRewritePattern<tosa::TransposeOp> {
263 LogicalResult matchAndRewrite(tosa::TransposeOp op,
264 PatternRewriter &rewriter)
const override {
265 auto outputType = cast<ShapedType>(op.getType());
267 if (!outputType.getElementType().isIntOrIndexOrFloat())
270 ElementsAttr inputValues;
274 if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
277 auto permValues = llvm::map_to_vector(
278 op.getPerms(), [](
const int32_t v) { return static_cast<int64_t>(v); });
280 auto inputType = cast<ShapedType>(op.getInput1().getType());
282 auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
285 op,
"unsupported attribute or element type");
297 LogicalResult matchAndRewrite(ReciprocalOp recip,
298 PatternRewriter &rewriter)
const override {
299 auto inputTensor = recip.getInput1();
303 notifyIfNotConstantFloatTosaTensor(inputTensor, recip, rewriter);
304 if (
failed(preCondCheck)) {
309 DenseElementsAttr inputValues;
313 if (!constantUnaryOpShouldBeFolded(recip, inputValues)) {
315 recip,
"Currently, reciprocals will only be folded if the input "
316 "tensor has a single user");
320 auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
321 inputValues, &ReciprocalOp::calcOneElement,
337 for (
int64_t i = tensorShape.size() - 1; i >= 0; --i) {
338 position[i] = remaining % tensorShape[i];
339 remaining /= tensorShape[i];
351 for (
int64_t i = position.size() - 1; i >= 0; --i) {
352 index += position[i] * multiplierTmp;
353 multiplierTmp *= tensorShape[i];
358template <
typename OperationType>
359llvm::APInt calculateReducedValue(
const mlir::ElementsAttr &oldTensorAttr,
365 newShape[reductionAxis] = 1;
368 getPositionFromIndex(reductionIndex, newShape);
369 auto oldTensor = oldTensorAttr.getValues<llvm::APInt>();
371 position[reductionAxis] = 0;
372 int64_t indexAtOldTensor = getIndexFromPosition(position, oldShape);
373 llvm::APInt reducedValue = oldTensor[indexAtOldTensor];
375 for (
int64_t reductionAxisVal = 1; reductionAxisVal < oldShape[reductionAxis];
376 ++reductionAxisVal) {
378 int64_t stride = llvm::product_of(oldShape.drop_front(reductionAxis + 1));
379 int64_t index = indexAtOldTensor + stride * reductionAxisVal;
381 OperationType::calcOneElement(reducedValue, oldTensor[
index]);
386template <
typename OperationType>
389 ReduceConstantOptimization(MLIRContext *context,
390 bool aggressiveReduceConstant)
391 : OpRewritePattern<OperationType>(context),
392 aggressiveReduceConstant(aggressiveReduceConstant) {}
394 using OpRewritePattern<OperationType>::OpRewritePattern;
396 LogicalResult matchAndRewrite(OperationType op,
397 PatternRewriter &rewriter)
const override {
398 Value inputOp = op.getInput();
403 op,
"reduce input must be const operation");
405 if (!inputOp.
hasOneUse() && !this->aggressiveReduceConstant)
407 op,
"input operation has more than one user");
409 auto resultType = cast<ShapedType>(op.getOutput().getType());
411 if (!resultType.hasStaticShape())
414 auto reductionAxis = op.getAxis();
415 const auto denseElementsAttr = constOp.getValues();
416 const auto shapedOldElementsValues =
417 cast<ShapedType>(denseElementsAttr.getType());
419 if (!llvm::isa<IntegerType>(shapedOldElementsValues.getElementType()))
421 op,
"reduce input currently supported with integer type");
423 auto oldShape = shapedOldElementsValues.getShape();
424 auto newShape = resultType.getShape();
426 int64_t newNumOfElements = llvm::product_of(newShape);
427 llvm::SmallVector<APInt> newReducedTensor(newNumOfElements);
429 for (int64_t reductionIndex = 0; reductionIndex < newNumOfElements;
433 newReducedTensor[reductionIndex] = calculateReducedValue<OperationType>(
434 denseElementsAttr, oldShape, reductionAxis, reductionIndex);
437 auto rankedTensorType = cast<RankedTensorType>(resultType);
443 const bool aggressiveReduceConstant;
450 bool aggressiveReduceConstant) {
// Register the constant-folding pattern once per TOSA reduction op kind.
// Every instance receives the same aggressiveReduceConstant flag; when set,
// constants with more than one user are folded as well (see the
// hasOneUse() check in ReduceConstantOptimization::matchAndRewrite).
451 patterns.add<ReduceConstantOptimization<ReduceAllOp>>(
452 ctx, aggressiveReduceConstant);
453 patterns.add<ReduceConstantOptimization<ReduceAnyOp>>(
454 ctx, aggressiveReduceConstant);
455 patterns.add<ReduceConstantOptimization<ReduceMaxOp>>(
456 ctx, aggressiveReduceConstant);
457 patterns.add<ReduceConstantOptimization<ReduceMinOp>>(
458 ctx, aggressiveReduceConstant);
459 patterns.add<ReduceConstantOptimization<ReduceProductOp>>(
460 ctx, aggressiveReduceConstant);
461 patterns.add<ReduceConstantOptimization<ReduceSumOp>>(
462 ctx, aggressiveReduceConstant);
467 patterns.add<TosaFoldConstantTranspose>(ctx);
472 patterns.add<TosaFoldConstantReciprocal>(ctx);
This class represents a processed binary blob of data.
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
int64_t getNumElements() const
Returns the number of elements held by this attribute.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Returns true if the given buffer is a valid raw buffer for the given type.
Type getElementType() const
Return the element type of this DenseElementsAttr.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
MLIRContext is the top-level object for a collection of MLIR operations.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and to provide a callback that populates a diagnostic with the reason for the failure.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification; the result values of the two ops must be the same types.
bool hasOneUse() const
Returns true if this value has exactly one use.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void populateTosaConstantReduction(MLIRContext *ctx, RewritePatternSet &patterns, bool aggressiveReduceConstant)
void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx, RewritePatternSet &patterns)
void populateTosaFoldConstantTransposePatterns(MLIRContext *ctx, RewritePatternSet &patterns)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
const FrozenRewritePatternSet & patterns
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.