MLIR 22.0.0git
VectorPattern.h
Go to the documentation of this file.
1//===- VectorPattern.h - Conversion pattern to the LLVM dialect -*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H
10#define MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H
11
14
15namespace mlir {
16
17namespace LLVM {
18namespace detail {
19// Helper struct to "unroll" operations on n-D vectors in terms of operations on
20// 1-D LLVM vectors.
22 // LLVM array type which encodes n-D vectors.
24 // LLVM vector type which encodes the inner 1-D vector type.
26 // Multiplicity of llvmNDVectorTy to llvm1DVectorTy.
28};
29
30// For >1-D vector types, extracts the necessary information to iterate over all
31// 1-D subvectors in the underlying LLVM representation of the n-D vector.
32// Iterates on the LLVM array type until we hit a non-array type (which is
33// asserted to be an LLVM vector type).
34NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
35 const LLVMTypeConverter &converter);
36
37// Expresses `linearIndex` in terms of the coordinates of `basis`.
38// Returns the empty vector when `linearIndex` is outside the range [0, P), where
39// P is the product of all the basis coordinates.
40//
41// Prerequisites:
42// Basis is an array of nonnegative integers (signed type inherited from
43// vector shape type).
45 unsigned linearIndex);
46
47// Iterates over every linear index, converts each one to coordinates, and calls
48// the callback at each position (e.g. to insert a splatted 1-D vector there).
49void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
51
53 Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter,
54 std::function<Value(Type, ValueRange)> createOperand,
55 ConversionPatternRewriter &rewriter);
56
// Shared lowering behind VectorConvertToLLVMPattern: replaces `op` (always a
// single-result op when called from that pattern) with an operation named
// `targetOp`, built from `operands` (the pattern's adaptor operands), carrying
// `targetAttrs` as its attributes and `propertiesAttr` as its properties
// (`propertiesAttr` may be a null attribute, meaning default properties).
// `typeConverter` and `rewriter` are forwarded from the calling pattern.
// NOTE(review): n-D vector results are presumably unrolled into 1-D LLVM
// vector ops (cf. handleMultidimensionalVectors) — confirm in the .cpp.
57LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
58 ValueRange operands,
59 ArrayRef<NamedAttribute> targetAttrs,
60 Attribute propertiesAttr,
61 const LLVMTypeConverter &typeConverter,
62 ConversionPatternRewriter &rewriter);
63} // namespace detail
64} // namespace LLVM
65
66// Default attribute conversion class, which passes all source attributes
67// through to the target op unmodified. The attribute used to set the target
68// operation's properties is a null attribute (i.e. any properties that exist on
69// the target operation will have default values).
70template <typename SourceOp, typename TargetOp>
72public:
73 AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {}
74
75 ArrayRef<NamedAttribute> getAttrs() const { return srcAttrs; }
76 Attribute getPropAttr() const { return {}; }
77
78private:
80};
81
82/// Basic lowering implementation to rewrite Ops with just one result to the
83/// LLVM Dialect. This supports higher-dimensional vector types.
84/// The AttrConvert template template parameter should:
85/// - be a template class with SourceOp and TargetOp type parameters
86/// - have a constructor that takes a SourceOp instance
87/// - have a getAttrs() method that returns an ArrayRef<NamedAttribute>
88///   containing the attributes that the target operation will have
89/// - have a getPropAttr() method that returns either a null attribute or a
90///   DictionaryAttr with the properties that exist on the target operation
91template <typename SourceOp, typename TargetOp,
92 template <typename, typename> typename AttrConvert =
93 AttrConvertPassThrough,
94 bool FailOnUnsupportedFP = false>
96 : public ConvertOpToLLVMPattern<SourceOp, FailOnUnsupportedFP> {
97public:
98 using ConvertOpToLLVMPattern<SourceOp,
99 FailOnUnsupportedFP>::ConvertOpToLLVMPattern;
101
102 LogicalResult
103 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
104 ConversionPatternRewriter &rewriter) const override {
105 static_assert(
106 std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
107 "expected single result op");
108
109 // Bail on unsupported floating point types. (These are type-converted to
110 // integer types.)
111 if (FailOnUnsupportedFP && LLVM::detail::opHasUnsupportedFloatingPointTypes(
112 op, *this->typeConverter)) {
113 return rewriter.notifyMatchFailure(op, "unsupported floating point type");
114 }
115
116 // Determine attributes for the target op
117 AttrConvert<SourceOp, TargetOp> attrConvert(op);
118
120 op, TargetOp::getOperationName(), adaptor.getOperands(),
121 attrConvert.getAttrs(), attrConvert.getPropAttr(),
122 *this->getTypeConverter(), rewriter);
123 }
124};
125} // namespace mlir
126
127#endif // MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H
AttrConvertPassThrough(SourceOp srcOp)
ArrayRef< NamedAttribute > getAttrs() const
Attribute getPropAttr() const
Attributes are known-constant values of operations.
Definition Attributes.h:25
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:222
Conversion from types to the LLVM IR dialect.
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
VectorConvertToLLVMPattern< SourceOp, TargetOp > Super
LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, Attribute propertiesAttr, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
bool opHasUnsupportedFloatingPointTypes(Operation *op, const TypeConverter &typeConverter)
Return "true" if the given op has any unsupported floating point types (either operands or results).
Definition Pattern.cpp:541
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, const LLVMTypeConverter &converter)
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
SmallVector< int64_t, 4 > arraySizes