MLIR 23.0.0git
Utils.cpp
Go to the documentation of this file.
1//===- Utils.cpp - Utilities to support the Arith dialect -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities for the Arith dialect.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/IR/Diagnostics.h"
19#include "llvm/ADT/SmallBitVector.h"
20#include "llvm/ADT/SmallVectorExtras.h"
21#include <numeric>
22
23using namespace mlir;
24
25std::optional<SmallVector<OpFoldResult>>
27 ShapedType expandedType,
29 ArrayRef<OpFoldResult> inputShape) {
30
31 SmallVector<Value> outputShapeValues;
32 SmallVector<int64_t> outputShapeInts;
33 // For zero-rank inputs, all dims in result shape are unit extent.
34 if (inputShape.empty()) {
35 outputShapeInts.resize(expandedType.getRank(), 1);
36 return getMixedValues(outputShapeInts, outputShapeValues, b);
37 }
38
39 // Check for all static shapes.
40 if (expandedType.hasStaticShape()) {
41 ArrayRef<int64_t> staticShape = expandedType.getShape();
42 outputShapeInts.assign(staticShape.begin(), staticShape.end());
43 return getMixedValues(outputShapeInts, outputShapeValues, b);
44 }
45
46 outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
47 for (const auto &it : llvm::enumerate(reassociation)) {
48 ReassociationIndices indexGroup = it.value();
49
50 int64_t indexGroupStaticSizesProductInt = 1;
51 bool foundDynamicShape = false;
52 for (int64_t index : indexGroup) {
53 int64_t outputDimSize = expandedType.getDimSize(index);
54 // Cannot infer expanded shape with multiple dynamic dims in the
55 // same reassociation group!
56 if (ShapedType::isDynamic(outputDimSize)) {
57 if (foundDynamicShape)
58 return std::nullopt;
59 foundDynamicShape = true;
60 } else {
61 outputShapeInts[index] = outputDimSize;
62 indexGroupStaticSizesProductInt *= outputDimSize;
63 }
64 }
65 if (!foundDynamicShape)
66 continue;
67
68 int64_t inputIndex = it.index();
69 // Call get<Value>() under the assumption that we're not casting
70 // dynamism.
71 Value indexGroupSize = cast<Value>(inputShape[inputIndex]);
72 Value indexGroupStaticSizesProduct =
73 arith::ConstantIndexOp::create(b, loc, indexGroupStaticSizesProductInt);
74 Value dynamicDimSize = b.createOrFold<arith::DivSIOp>(
75 loc, indexGroupSize, indexGroupStaticSizesProduct);
76 outputShapeValues.push_back(dynamicDimSize);
77 }
78
79 if ((int64_t)outputShapeValues.size() !=
80 llvm::count(outputShapeInts, ShapedType::kDynamic))
81 return std::nullopt;
82
83 return getMixedValues(outputShapeInts, outputShapeValues, b);
84}
85
86/// Matches a ConstantIndexOp.
87/// TODO: This should probably just be a general matcher that uses matchConstant
88/// and checks the operation for an index type.
92
93llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
95 llvm::SmallBitVector dimsToProject(shape.size());
96 for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) {
97 if (shape[pos] == 1) {
98 dimsToProject.set(pos);
99 --rank;
100 }
101 }
102 return dimsToProject;
103}
104
106 OpFoldResult ofr) {
107 if (auto value = dyn_cast_if_present<Value>(ofr))
108 return value;
109 auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
110 return arith::ConstantOp::create(
111 b, loc, b.getIntegerAttr(attr.getType(), attr.getValue().getSExtValue()));
112}
113
115 OpFoldResult ofr) {
116 if (auto value = dyn_cast_if_present<Value>(ofr))
117 return value;
118 auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
119 return arith::ConstantIndexOp::create(b, loc, attr.getValue().getSExtValue());
120}
121
123 Type targetType, Value value) {
124 if (targetType == value.getType())
125 return value;
126
127 bool targetIsIndex = targetType.isIndex();
128 bool valueIsIndex = value.getType().isIndex();
129 if (targetIsIndex ^ valueIsIndex)
130 return arith::IndexCastOp::create(b, loc, targetType, value);
131
132 auto targetIntegerType = dyn_cast<IntegerType>(targetType);
133 auto valueIntegerType = dyn_cast<IntegerType>(value.getType());
134 assert(targetIntegerType && valueIntegerType &&
135 "unexpected cast between types other than integers and index");
136 assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
137
138 if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
139 return arith::ExtSIOp::create(b, loc, targetIntegerType, value);
140 return arith::TruncIOp::create(b, loc, targetIntegerType, value);
141}
142
144 IntegerType toType, bool isUnsigned) {
145 // If operand is floating point, cast directly to the int type.
146 if (isa<FloatType>(operand.getType())) {
147 if (isUnsigned)
148 return arith::FPToUIOp::create(b, toType, operand);
149 return arith::FPToSIOp::create(b, toType, operand);
150 }
151 // Cast index operands directly to the int type.
152 if (operand.getType().isIndex())
153 return arith::IndexCastOp::create(b, toType, operand);
154 if (auto fromIntType = dyn_cast<IntegerType>(operand.getType())) {
155 // Either extend or truncate.
156 if (toType.getWidth() > fromIntType.getWidth()) {
157 if (isUnsigned)
158 return arith::ExtUIOp::create(b, toType, operand);
159 return arith::ExtSIOp::create(b, toType, operand);
160 }
161 if (toType.getWidth() < fromIntType.getWidth())
162 return arith::TruncIOp::create(b, toType, operand);
163 return operand;
164 }
165
166 return {};
167}
168
170 FloatType toType, bool isUnsigned) {
171 // If operand is integer, cast directly to the float type.
172 // Note that it is unclear how to cast from BF16<->FP16.
173 if (isa<IntegerType>(operand.getType())) {
174 if (isUnsigned)
175 return arith::UIToFPOp::create(b, toType, operand);
176 return arith::SIToFPOp::create(b, toType, operand);
177 }
178 if (auto fromFpTy = dyn_cast<FloatType>(operand.getType())) {
179 if (toType.getWidth() > fromFpTy.getWidth())
180 return arith::ExtFOp::create(b, toType, operand);
181 if (toType.getWidth() < fromFpTy.getWidth())
182 return arith::TruncFOp::create(b, toType, operand);
183 return operand;
184 }
185
186 return {};
187}
188
190 ComplexType targetType,
191 bool isUnsigned) {
192 if (auto fromComplexType = dyn_cast<ComplexType>(operand.getType())) {
193 if (isa<FloatType>(targetType.getElementType()) &&
194 isa<FloatType>(fromComplexType.getElementType())) {
195 Value real = complex::ReOp::create(b, operand);
196 Value imag = complex::ImOp::create(b, operand);
197 Type targetETy = targetType.getElementType();
198 if (targetType.getElementType().getIntOrFloatBitWidth() <
199 fromComplexType.getElementType().getIntOrFloatBitWidth()) {
200 real = arith::TruncFOp::create(b, targetETy, real);
201 imag = arith::TruncFOp::create(b, targetETy, imag);
202 } else {
203 real = arith::ExtFOp::create(b, targetETy, real);
204 imag = arith::ExtFOp::create(b, targetETy, imag);
205 }
206 return complex::CreateOp::create(b, targetType, real, imag);
207 }
208 }
209
210 if (isa<FloatType>(operand.getType())) {
211 FloatType toFpTy = cast<FloatType>(targetType.getElementType());
212 auto toBitwidth = toFpTy.getIntOrFloatBitWidth();
213 Value from = operand;
214 if (from.getType().getIntOrFloatBitWidth() < toBitwidth) {
215 from = arith::ExtFOp::create(b, toFpTy, from);
216 }
217 if (from.getType().getIntOrFloatBitWidth() > toBitwidth) {
218 from = arith::TruncFOp::create(b, toFpTy, from);
219 }
221 b, toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
222 return complex::CreateOp::create(b, targetType, from, zero);
223 }
224
225 if (isa<IntegerType>(operand.getType())) {
226 FloatType toFpTy = cast<FloatType>(targetType.getElementType());
227 Value from = operand;
228 if (isUnsigned) {
229 from = arith::UIToFPOp::create(b, toFpTy, from);
230 } else {
231 from = arith::SIToFPOp::create(b, toFpTy, from);
232 }
234 b, toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
235 return complex::CreateOp::create(b, targetType, from, zero);
236 }
237
238 return {};
239}
240
242 Type toType, bool isUnsignedCast) {
243 if (operand.getType() == toType)
244 return operand;
245 ImplicitLocOpBuilder ib(loc, b);
247 if (auto intTy = dyn_cast<IntegerType>(toType)) {
248 result = convertScalarToIntDtype(ib, operand, intTy, isUnsignedCast);
249 } else if (auto floatTy = dyn_cast<FloatType>(toType)) {
250 result = convertScalarToFpDtype(ib, operand, floatTy, isUnsignedCast);
251 } else if (auto complexTy = dyn_cast<ComplexType>(toType)) {
252 result =
253 convertScalarToComplexDtype(ib, operand, complexTy, isUnsignedCast);
254 }
255
256 if (result)
257 return result;
258
259 emitWarning(loc) << "could not cast operand of type " << operand.getType()
260 << " to " << toType;
261 return operand;
262}
263
266 ArrayRef<OpFoldResult> valueOrAttrVec) {
267 return llvm::map_to_vector<4>(
268 valueOrAttrVec, [&](OpFoldResult value) -> Value {
269 return getValueOrCreateConstantIndexOp(b, loc, value);
270 });
271}
272
274 Type type, const APInt &value) {
275 TypedAttr attr;
276 if (isa<IntegerType>(type)) {
277 attr = builder.getIntegerAttr(type, value);
278 } else {
279 auto vecTy = cast<ShapedType>(type);
280 attr = SplatElementsAttr::get(vecTy, value);
281 }
282
283 return arith::ConstantOp::create(builder, loc, attr);
284}
285
287 Type type, int64_t value) {
288 unsigned elementBitWidth = 0;
289 if (auto intTy = dyn_cast<IntegerType>(type))
290 elementBitWidth = intTy.getWidth();
291 else
292 elementBitWidth = cast<ShapedType>(type).getElementTypeBitWidth();
293
294 return createScalarOrSplatConstant(builder, loc, type,
295 APInt(elementBitWidth, value));
296}
297
299 Type type, const APFloat &value) {
300 if (isa<FloatType>(type))
301 return builder.createOrFold<arith::ConstantOp>(
302 loc, type, builder.getFloatAttr(type, value));
303 TypedAttr splat = SplatElementsAttr::get(cast<ShapedType>(type), value);
304 return builder.createOrFold<arith::ConstantOp>(loc, type, splat);
305}
306
308 if (auto value = dyn_cast_if_present<Value>(ofr))
309 return value.getType();
310 auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
311 return attr.getType();
312}
313
315 return arith::AndIOp::create(b, loc, lhs, rhs);
316}
318 if (isa<FloatType>(lhs.getType()))
319 return arith::AddFOp::create(b, loc, lhs, rhs);
320 return arith::AddIOp::create(b, loc, lhs, rhs, ovf);
321}
323 if (isa<FloatType>(lhs.getType()))
324 return arith::SubFOp::create(b, loc, lhs, rhs);
325 return arith::SubIOp::create(b, loc, lhs, rhs, ovf);
326}
328 if (isa<FloatType>(lhs.getType()))
329 return arith::MulFOp::create(b, loc, lhs, rhs);
330 return arith::MulIOp::create(b, loc, lhs, rhs, ovf);
331}
333 if (isa<FloatType>(lhs.getType()))
334 return arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OGT, lhs, rhs);
335 return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sgt, lhs, rhs);
336}
338 if (isa<FloatType>(lhs.getType()))
339 return arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OLT, lhs, rhs);
340 return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::slt, lhs, rhs);
341}
343 return arith::SelectOp::create(b, loc, cmp, lhs, rhs);
344}
345
346namespace mlir::arith {
347
349 return createProduct(builder, loc, values, values.front().getType());
350}
351
353 Type resultType) {
354 Value one = ConstantOp::create(builder, loc, resultType,
355 builder.getOneAttr(resultType));
356 ArithBuilder arithBuilder(builder, loc);
357 return llvm::accumulate(values, one, [&arithBuilder](Value acc, Value v) {
358 return arithBuilder.mul(acc, v);
359 });
360}
361
362FloatType parseFloatType(MLIRContext *ctx, StringRef name) {
363 // Parsing non-builtin types is unsafe because the respective dialect may not
364 // have been loaded.
365 if (!name.empty() && name.front() == '!')
366 return FloatType();
367
368 // Suppress diagnostics: callers handle invalid type strings themselves.
369 ScopedDiagnosticHandler handler(ctx, [](Diagnostic &) {});
370 return dyn_cast_or_null<FloatType>(mlir::parseType(name, ctx));
371}
372
373} // namespace mlir::arith
static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand, ComplexType targetType, bool isUnsigned)
Definition Utils.cpp:189
static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand, IntegerType toType, bool isUnsigned)
Definition Utils.cpp:143
static Value convertScalarToFpDtype(ImplicitLocOpBuilder &b, Value operand, FloatType toType, bool isUnsigned)
Definition Utils.cpp:169
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:258
TypedAttr getOneAttr(Type type)
Definition Builders.cpp:346
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
This class represents a single result from folding an operation.
This diagnostic handler is a simple RAII class that registers and erases a diagnostic handler on a gi...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
Definition ArithOps.cpp:334
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
FloatType parseFloatType(MLIRContext *ctx, StringRef name)
Definition Utils.cpp:362
Value createProduct(OpBuilder &builder, Location loc, ArrayRef< Value > values)
Definition Utils.cpp:348
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition Utils.cpp:241
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
Definition Utils.cpp:273
Type getType(OpFoldResult ofr)
Returns the type of ofr: the type of the Value it holds, or of its IntegerAttr.
Definition Utils.cpp:307
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:105
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition Utils.cpp:122
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition Utils.cpp:26
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition Utils.cpp:93
detail::op_matcher< arith::ConstantIndexOp > matchConstantIndex()
Matches a ConstantIndexOp.
Definition Utils.cpp:89
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
Helper struct to build simple arithmetic quantities with minimal type inference support.
Definition Utils.h:103
Value mul(Value lhs, Value rhs)
Definition Utils.cpp:327
Value _and(Value lhs, Value rhs)
Definition Utils.cpp:314
Value slt(Value lhs, Value rhs)
Definition Utils.cpp:337
Value select(Value cmp, Value lhs, Value rhs)
Definition Utils.cpp:342
Value add(Value lhs, Value rhs)
Definition Utils.cpp:317
Value sgt(Value lhs, Value rhs)
Definition Utils.cpp:332
Value sub(Value lhs, Value rhs)
Definition Utils.cpp:322
The matcher that matches a certain kind of op.
Definition Matchers.h:283