TosaOps.cpp
1//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// This file implements the TOSA Specification:
11// https://www.mlplatform.org/tosa/tosa_spec.html
12//
13//===----------------------------------------------------------------------===//
14
25#include "mlir/IR/Matchers.h"
29#include "llvm/ADT/APFloat.h"
30#include "llvm/ADT/SmallVectorExtras.h"
31#include "llvm/ADT/TypeSwitch.h"
32
33#include <numeric>
34
35using namespace mlir;
36using namespace mlir::tosa;
37
38#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
40
41//===----------------------------------------------------------------------===//
42// Tosa dialect interface includes.
43//===----------------------------------------------------------------------===//
44
45#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
46#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
47#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
48#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
49
50namespace {
51#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
52
53//===----------------------------------------------------------------------===//
54// Dialect Function Inliner Interface.
55//===----------------------------------------------------------------------===//
56struct TosaInlinerInterface : public DialectInlinerInterface {
57 using DialectInlinerInterface::DialectInlinerInterface;
58
59 //===--------------------------------------------------------------------===//
60 // Analysis Hooks.
61 //===--------------------------------------------------------------------===//
62
63 /// All operations can be inlined by default.
64 bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
65 IRMapping &map) const final {
66 return true;
67 }
68
69 /// All regions with If and While parent operators can be inlined.
70 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
71 IRMapping &map) const final {
72 return (isa<tosa::IfOp>(dest->getParentOp()) ||
73 isa<tosa::WhileOp>(dest->getParentOp()));
74 }
75};
76
77/// This class implements the bytecode interface for the Tosa dialect.
78struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
79 TosaDialectBytecodeInterface(Dialect *dialect)
80 : BytecodeDialectInterface(dialect) {}
81
82 //===--------------------------------------------------------------------===//
83 // Attributes
84
85 Attribute readAttribute(DialectBytecodeReader &reader) const override {
86 return ::readAttribute(getContext(), reader);
87 }
88
89 LogicalResult writeAttribute(Attribute attr,
90 DialectBytecodeWriter &writer) const override {
91 return ::writeAttribute(attr, writer);
92 }
93
94 //===--------------------------------------------------------------------===//
95 // Types
96
97 Type readType(DialectBytecodeReader &reader) const override {
98 return ::readType(getContext(), reader);
99 }
100
101 LogicalResult writeType(Type type,
102 DialectBytecodeWriter &writer) const override {
103 return ::writeType(type, writer);
104 }
105
106 void writeVersion(DialectBytecodeWriter &writer) const final {
107 // TODO: Populate.
108 }
109
110 std::unique_ptr<DialectVersion>
111 readVersion(DialectBytecodeReader &reader) const final {
112 // TODO: Populate
113 reader.emitError("Dialect does not support versioning");
114 return nullptr;
115 }
116
117 LogicalResult upgradeFromVersion(Operation *topLevelOp,
118 const DialectVersion &version) const final {
119 return success();
120 }
121};
122
123} // namespace
124
125//===----------------------------------------------------------------------===//
126// TOSA control flow support.
127//===----------------------------------------------------------------------===//
128
129/// Returns the while loop body.
130SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
131 return {&getBodyGraph()};
132}
133
134//===----------------------------------------------------------------------===//
135// TOSA variable operator support.
136//===----------------------------------------------------------------------===//
137
138SmallVector<int64_t> mlir::tosa::convertToMlirShape(ArrayRef<int64_t> shape) {
139 return map_to_vector(shape, [](int64_t dim) {
140 return dim == -1 ? ShapedType::kDynamic : dim;
141 });
142}
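// For example, convertToMlirShape({2, -1, 8}) yields
// {2, ShapedType::kDynamic, 8}: -1 in the stored shape marks a dynamic dim.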
143
144// Returns the ranked tensor type of a tosa.variable op.
145RankedTensorType mlir::tosa::getVariableType(tosa::VariableOp variableOp) {
146 Type elementType = variableOp.getType();
147 DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
148 auto shape = convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
149 return RankedTensorType::get(shape, elementType);
150}
151
152//===----------------------------------------------------------------------===//
153// Tosa dialect initialization.
154//===----------------------------------------------------------------------===//
155
156void TosaDialect::initialize() {
157 addTypes<
158#define GET_TYPEDEF_LIST
159#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
160 >();
161 addOperations<
162#define GET_OP_LIST
163#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
164 >();
165 addAttributes<
166#define GET_ATTRDEF_LIST
167#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
168 >();
169 addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
170 declarePromisedInterfaces<
171 shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
172 ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
173 LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
174 LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
175 BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
176 NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
177 GreaterEqualOp, MatMulOp>();
178}
179
180Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
181 Type type, Location loc) {
182 // Tosa dialect constants only support ElementsAttr, unlike the standard
183 // dialect constant, which supports all attributes.
184 if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
185 return tosa::ConstShapeOp::create(builder, loc, type,
186 llvm::cast<DenseIntElementsAttr>(value));
187 }
188 if (llvm::isa<ElementsAttr>(value))
189 return tosa::ConstOp::create(builder, loc, type,
190 llvm::cast<ElementsAttr>(value));
191 return nullptr;
192}
193
194//===----------------------------------------------------------------------===//
195// Parsers and printers
196//===----------------------------------------------------------------------===//
197
198namespace {
199
200ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
201 DenseElementsAttr &varShapeAttr,
202 TypeAttr &typeAttr) {
203 if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
204 if (!shapedType.hasRank())
205 return parser.emitError(parser.getCurrentLocation())
206 << "expected ranked type";
207
208 auto elementType = shapedType.getElementType();
209 typeAttr = TypeAttr::get(elementType);
210 ArrayRef<int64_t> shape = shapedType.getShape();
211 Builder builder(parser.getContext());
212 varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
213 return success();
214 }
215 return parser.emitError(parser.getCurrentLocation())
216 << "expected shaped type";
217}
218
219} // namespace
220
221// parses the optional initial value or type for a tosa variable
222// with initial value:
223// tosa.variable @name = dense<0.0> : tensor<1x8xf32>
224//
225// without initial value:
226// tosa.variable @name : tensor<1x8xf32>
227ParseResult mlir::tosa::parseVariableOpTypeOrInitialValue(
228 OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
229 Attribute &initialValueAttr) {
230 if (succeeded(parser.parseOptionalEqual())) {
231 if (failed(parser.parseAttribute(initialValueAttr))) {
232 return parser.emitError(parser.getCurrentLocation())
233 << "expected attribute";
234 }
235 if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
236 return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
237 typeAttr);
238 }
239 return parser.emitError(parser.getCurrentLocation())
240 << "expected Typed attr";
241 }
242
243 initialValueAttr = nullptr;
244 Type parsedType;
245 if (failed(parser.parseColonType(parsedType))) {
246 return parser.emitError(parser.getCurrentLocation())
247 << "expected type after colon";
248 }
249 return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
250}
251
252void mlir::tosa::printVariableOpTypeOrInitialValue(
253 OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
254 TypeAttr typeAttr, Attribute initialValueAttr) {
255 bool needsSpace = false;
256 if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
257 auto shape =
258 convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
259 Type elementType = typeAttr.getValue();
260 RankedTensorType tensorType =
261 RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
262 auto tensorTypeAttr = TypeAttr::get(tensorType);
263 p << ": ";
264 p.printAttribute(tensorTypeAttr);
265 needsSpace = true; // subsequent attr value needs a space separator
266 }
267 if (initialValueAttr) {
268 if (needsSpace)
269 p << ' ';
270 p << "= ";
271 p.printAttribute(initialValueAttr);
272 }
273}
274
275namespace {
276
277// parse attributes with special handling for tosa enum attributes
278template <typename EnumType>
279ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
280 NamedAttrList &outAttrs) {
281 llvm::StringRef name;
282 if (parser.parseOptionalKeyword(&name) || parser.parseEqual())
283 return failure();
284
285 // special handling: rounding_mode accepts a *bare* RoundingMode enum
286 // keyword.
287 llvm::StringRef kw;
288 if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
289 if (name == "rounding_mode" &&
290 succeeded(parser.parseOptionalKeyword(&kw))) {
291 auto sym = symbolizeRoundingMode(kw);
292 if (!sym)
293 return parser.emitError(parser.getCurrentLocation())
294 << "invalid rounding_mode value: " << kw;
295 auto attr = RoundingModeAttr::get(parser.getContext(), sym.value());
296 outAttrs.push_back(NamedAttribute(name, attr));
297 return success();
298 }
299 }
300 // special handling: mode accepts a *bare* ResizeMode enum keyword.
301 if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
302 if (name == "mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
303 auto sym = symbolizeResizeMode(kw);
304 if (!sym)
305 return parser.emitError(parser.getCurrentLocation())
306 << "invalid resize mode value: " << kw;
307 auto attr = ResizeModeAttr::get(parser.getContext(), sym.value());
308 outAttrs.push_back(NamedAttribute(name, attr));
309 return success();
310 }
311 }
312 // special handling: nan_mode accepts a *bare* NanPropagationMode enum
313 // keyword.
314 if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
315 if (name == "nan_mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
316 auto sym = symbolizeNanPropagationMode(kw);
317 if (!sym)
318 return parser.emitError(parser.getCurrentLocation())
319 << "invalid nan_mode value: " << kw;
320 auto attr = NanPropagationModeAttr::get(parser.getContext(), sym.value());
321 outAttrs.push_back(NamedAttribute(name, attr));
322 return success();
323 }
324 }
325
326 // special handling: block_size accepts a *bare* BlockSize enum keyword.
327 if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) {
328 if (name == "block_size" && succeeded(parser.parseOptionalKeyword(&kw))) {
329 auto sym = symbolizeBlockSize(kw);
330 if (!sym)
331 return parser.emitError(parser.getCurrentLocation())
332 << "invalid block_size value: " << kw;
333 auto attr = BlockSizeAttr::get(parser.getContext(), sym.value());
334 outAttrs.push_back(NamedAttribute(name, attr));
335 return success();
336 }
337 }
338
339 // Default path: parse any normal attribute literal, including fully qualified
340 // enum keyword
341 Attribute attr;
342 return parser.parseAttribute(attr, name, outAttrs);
343}
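// Sketch of the two attribute forms accepted above (illustrative, assuming the
// usual tosa enum spelling): the bare keyword entry
//   rounding_mode = SINGLE_ROUND
// is turned into a RoundingModeAttr by the special handling, while a fully
// qualified attribute such as #tosa<rounding_mode SINGLE_ROUND> goes through
// the default parseAttribute path and yields the same attribute.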
344
345template <typename EnumType>
346ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
347 // parse operands
348 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
349 if (parser.parseCommaSeparatedList(
350 [&]() { return parser.parseOperand(operands.emplace_back()); }))
351 return failure();
352
353 // Parse { attr-dict } with special handling for enum bare token
354 NamedAttrList attrs;
355 if (succeeded(parser.parseOptionalLBrace()) &&
356 failed(parser.parseOptionalRBrace())) {
357 do {
358 if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))
359 return failure();
360 } while (succeeded(parser.parseOptionalComma()));
361 if (parser.parseRBrace())
362 return failure();
363 }
364
365 FunctionType fnTy;
366 if (parser.parseColonType(fnTy))
367 return failure();
368
369 // Resolve operands and types
370 if (failed(parser.resolveOperands(operands, fnTy.getInputs(),
371 parser.getCurrentLocation(),
372 result.operands)))
373 return failure();
374
375 result.addTypes(fnTy.getResults());
376 result.addAttributes(attrs);
377
378 return success();
379}
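// Informally, the syntax accepted above is:
//   operand (`,` operand)*
//     (`{` attr-entry (`,` attr-entry)* `}`)? `:` function-type
// As an illustrative sketch (not a normative example), a clamp could be written
//   %0 = tosa.clamp %arg0 {min_val = 0.0 : f32, max_val = 1.0 : f32,
//                          nan_mode = IGNORE} : (tensor<4xf32>) -> tensor<4xf32>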
380
381void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) {
382 parser << namedAttr.getName().strref() << " = ";
383 auto attr = namedAttr.getValue();
384 if (auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
385 parser << roundingModeAttr.getValue();
386 } else if (auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
387 parser << resizeModeAttr.getValue();
388 } else if (auto nanPropagationModeAttr =
389 dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
390 parser << nanPropagationModeAttr.getValue();
391 } else if (auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) {
392 parser << blockSizeAttr.getValue();
393 } else {
394 parser.printAttribute(attr);
395 }
396}
397
398// print with special handling for default valued NanPropagationMode attribute
399void printWithNanPropagationHandling(OpAsmPrinter &parser, Operation *op) {
400 parser << " ";
401 parser.printOperands(op->getOperands());
402
403 NamedAttrList toPrint(op->getAttrs());
404 // remove default NanPropagate attribute
405 const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;
406 for (auto attr : op->getAttrs()) {
407 if (auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
408 if (nanAttr.getValue() == kDefaultNanValue) {
409 // elide from toPrint
410 toPrint.erase(attr.getName());
411 break;
412 }
413 }
414 }
415
416 if (!toPrint.empty()) {
417 parser << " {";
418 llvm::interleaveComma(toPrint, parser, [&](const NamedAttribute namedAttr) {
419 printNamedAttr(parser, namedAttr);
420 });
421 parser << "}";
422 }
423
424 parser << " : ";
425 parser.printFunctionalType(op);
426}
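// With the elision above, an op whose nan_mode is the default PROPAGATE prints
// with no nan_mode entry at all, while nan_mode = IGNORE is kept in the
// attribute dict.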
427
428// print with special handling for enums: RoundingMode, ResizeMode, BlockSize
429void printWithEnumHandling(OpAsmPrinter &parser, Operation *op) {
430 parser << " ";
431 parser.printOperands(op->getOperands());
432
433 if (!op->getAttrs().empty()) {
434 parser << " {";
435 llvm::interleaveComma(op->getAttrs(), parser,
436 [&](const NamedAttribute namedAttr) {
437 printNamedAttr(parser, namedAttr);
438 });
439 parser << "}";
440 }
441
442 parser << " : ";
443 parser.printFunctionalType(op);
444}
445
446} // namespace
447
448ParseResult RescaleOp::parse(OpAsmParser &parser, OperationState &result) {
449 return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
450}
451
452void RescaleOp::print(OpAsmPrinter &parser) {
453 printWithEnumHandling(parser, *this);
454}
455
456ParseResult ApplyScaleOp::parse(OpAsmParser &parser, OperationState &result) {
457 return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
458}
459
460void ApplyScaleOp::print(OpAsmPrinter &parser) {
461 printWithEnumHandling(parser, *this);
462}
463
464ParseResult ResizeOp::parse(OpAsmParser &parser, OperationState &result) {
465 return parseWithEnumHandling<tosa::ResizeMode>(parser, result);
466}
467
468void ResizeOp::print(OpAsmPrinter &parser) {
469 printWithEnumHandling(parser, *this);
470}
471
472ParseResult ArgMaxOp::parse(OpAsmParser &parser, OperationState &result) {
473 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
474}
475
476void ArgMaxOp::print(OpAsmPrinter &parser) {
477 printWithNanPropagationHandling(parser, *this);
478}
479
480ParseResult MaxPool2dOp::parse(OpAsmParser &parser, OperationState &result) {
481 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
482}
483
484void MaxPool2dOp::print(OpAsmPrinter &parser) {
485 printWithNanPropagationHandling(parser, *this);
486}
487
488ParseResult ClampOp::parse(OpAsmParser &parser, OperationState &result) {
489 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
490}
491
492void ClampOp::print(OpAsmPrinter &parser) {
493 printWithNanPropagationHandling(parser, *this);
494}
495
496ParseResult MaximumOp::parse(OpAsmParser &parser, OperationState &result) {
497 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
498}
499
500void MaximumOp::print(OpAsmPrinter &parser) {
501 printWithNanPropagationHandling(parser, *this);
502}
503
504ParseResult MinimumOp::parse(OpAsmParser &parser, OperationState &result) {
505 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
506}
507
508void MinimumOp::print(OpAsmPrinter &parser) {
509 printWithNanPropagationHandling(parser, *this);
510}
511
512ParseResult ReduceMaxOp::parse(OpAsmParser &parser, OperationState &result) {
513 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
514}
515
516void ReduceMaxOp::print(OpAsmPrinter &parser) {
517 printWithNanPropagationHandling(parser, *this);
518}
519
520ParseResult ReduceMinOp::parse(OpAsmParser &parser, OperationState &result) {
521 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
522}
523
524void ReduceMinOp::print(OpAsmPrinter &parser) {
525 printWithNanPropagationHandling(parser, *this);
526}
527
528ParseResult MatmulTBlockScaledOp::parse(OpAsmParser &parser,
529 OperationState &result) {
530 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
531}
532
533void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) {
534 printWithEnumHandling(parser, *this);
535}
536
537ParseResult CastFromBlockScaledOp::parse(OpAsmParser &parser,
538 OperationState &result) {
539 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
540}
541
542void CastFromBlockScaledOp::print(OpAsmPrinter &parser) {
543 printWithEnumHandling(parser, *this);
544}
545
546ParseResult CastToBlockScaledOp::parse(OpAsmParser &parser,
547 OperationState &result) {
548 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
549}
550
551void CastToBlockScaledOp::print(OpAsmPrinter &parser) {
552 printWithEnumHandling(parser, *this);
553}
554
555ParseResult Conv2DBlockScaledOp::parse(OpAsmParser &parser,
556 OperationState &result) {
557 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
558}
559
560void Conv2DBlockScaledOp::print(OpAsmPrinter &parser) {
561 printWithEnumHandling(parser, *this);
562}
563
564//===----------------------------------------------------------------------===//
565// Tosa utilities.
566//===----------------------------------------------------------------------===//
567
568static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
569 if (lhs % rhs != 0)
570 return std::nullopt;
571 return lhs / rhs;
572}
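// e.g. idivCheck(12, 4) == 3, while idivCheck(13, 4) == std::nullopt.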
573
574Type mlir::tosa::getStorageElementTypeOrSelf(Type type) {
575 auto srcType = getElementTypeOrSelf(type);
576 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
577 srcType = getStorageElementTypeFromQuantized(quantType);
578 return srcType;
579}
580
581Type mlir::tosa::getStorageElementTypeOrSelf(Value value) {
582 return getStorageElementTypeOrSelf(value.getType());
583}
584
585static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
586 Value valZp, StringRef name) {
587 Type eType = getStorageElementTypeOrSelf(val.getType());
588 Type eZpType = getStorageElementTypeOrSelf(valZp.getType());
589
590 bool bothInts =
591 mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
592 bool sameBitWidth =
593 (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());
594
595 if (!bothInts || !sameBitWidth) {
596 return op->emitOpError()
597 << "expected " << name << " and " << name
598 << "_zp to both be integer of the same bitwidth, but got " << eType
599 << " vs. " << eZpType;
600 }
601 return success();
602}
603
604// Create a pad_const constant tensor holding `val` in the required data type.
605Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
606 Value src, int32_t val) {
607 const auto srcType = getElementTypeOrSelf(src);
608 const auto srcElemType = getStorageElementTypeOrSelf(src);
609 const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
610 const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
611 const auto padConstAttr{
612 llvm::isa<FloatType>(srcElemType)
613 ? DenseElementsAttr::get(padConstEType,
614 builder.getFloatAttr(srcElemType, val))
615 : DenseElementsAttr::get(padConstEType,
616 builder.getIntegerAttr(srcElemType, val))};
617 return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
618}
619
621 if (dyn_cast<tosa::mxint8Type>(type))
622 return 8;
623 return type.getIntOrFloatBitWidth();
624}
625
626// Update dim size if current dim is dynamic, otherwise raise an error if sizes
627// do not match
628LogicalResult tryUpdateDimOrFailure(Operation *op, int64_t &currDim,
629 const int64_t newDim,
630 const StringRef operandName,
631 const StringRef dimName) {
632 if (ShapedType::isDynamic(currDim)) {
633 currDim = newDim;
634 return success();
635 } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
636 return op->emitOpError("expected ")
637 << dimName << " of " << operandName << " to match size " << currDim
638 << ", got " << newDim;
639 }
640 return success();
641}
642
643static LogicalResult verifyConvOutputSize(
644 Operation *op, const int64_t inputSize, const int64_t kernelSize,
645 const int64_t outputSize, const int64_t padBefore, const int64_t padAfter,
646 const int64_t stride, const int64_t dilation, const llvm::StringRef dimName,
647 const llvm::StringRef dimAxis, const llvm::StringRef padBeforeName,
648 const llvm::StringRef padAfterName) {
649 if (inputSize == ShapedType::kDynamic || kernelSize == ShapedType::kDynamic)
650 return success();
651
652 // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
653
654 const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
655 inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
656 stride);
657 if (!calculatedOutSizeMinusOne.has_value())
658 return op->emitOpError("expected input_")
659 << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
660 << padAfterName << " - (kernel_" << dimName << " - 1) * dilation_"
661 << dimAxis << " to be wholly divisible by stride_" << dimAxis
662 << ", got (" << inputSize << " - 1 + " << padBefore << " + "
663 << padAfter << " - (" << kernelSize << " - 1) * " << dilation
664 << ") / " << stride;
665
666 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
667 if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
668 return op->emitOpError("calculated output ")
669 << dimName << " did not match expected: "
670 << "calculated=" << calculatedOutSize << ", expected=" << outputSize;
671
672 return success();
673}
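// Worked example of the ERROR_IF rule above: with I = 15, K = 3, pa = pb = 1,
// d = 1 and s = 2, the only valid static output size is
// idiv_check(15 - 1 + 1 + 1 - (3 - 1) * 1, 2) + 1 = 14 / 2 + 1 = 8; a
// non-divisible numerator or any other static output size is rejected.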
674
675//===----------------------------------------------------------------------===//
676// TOSA Operator Verifiers.
677//===----------------------------------------------------------------------===//
678
679template <typename T>
680static LogicalResult verifyConvOp(T op) {
681 const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
682 const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());
683
684 auto inputEType = inputType.getElementType();
685 auto weightEType = weightType.getElementType();
686 auto biasEType =
687 llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
688 auto resultEType =
689 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
690 bool biasIsFloat = llvm::isa<FloatType>(biasEType);
691 bool resultIsFloat = llvm::isa<FloatType>(resultEType);
692
693 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
694 inputEType = getStorageElementTypeFromQuantized(quantType);
695
696 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
697 weightEType = getStorageElementTypeFromQuantized(quantType);
698
699 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
700 biasEType = getStorageElementTypeFromQuantized(quantType);
701
702 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
703 resultEType = getStorageElementTypeFromQuantized(quantType);
704
705 if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
706 // for now, only enforce bias element type == result element type for
707 // float types.
708 op.emitOpError(
709 "expect both bias and result to have same element type, got ")
710 << biasEType << " and " << resultEType;
711 return failure();
712 }
713
714 if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
715 isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
716 if (inputEType != weightEType) {
717 op.emitOpError(
718 "expect both input and weight to have same element type, got ")
719 << inputEType << " and " << weightEType;
720 return failure();
721 }
722 }
723
724 bool inputIsFloat = llvm::isa<FloatType>(inputEType);
725 bool weightIsFloat = llvm::isa<FloatType>(weightEType);
726
727 // Either both must be float or both non-float.
728 if (inputIsFloat != weightIsFloat) {
729 op.emitOpError(
730 "expect both input and weight to be float or not together, got ")
731 << inputEType << " and " << weightEType;
732 return failure();
733 }
734
735 auto inputZpEType = getStorageElementTypeOrSelf(op.getInputZp().getType());
736 if (inputEType != inputZpEType) {
737 return op.emitOpError("expect both input and its zero point are the same "
738 "element type, got ")
739 << inputEType << " and " << inputZpEType;
740 }
741
742 auto weightZpEType = getStorageElementTypeOrSelf(op.getWeightZp().getType());
743 if (weightEType != weightZpEType) {
744 return op.emitOpError("expect both weight and its zero point are the same "
745 "element type, got ")
746 << weightEType << " and " << weightZpEType;
747 }
748
749 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
750 if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
751 return failure();
752
753 FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
754 if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
755 return failure();
756
757 return success();
758}
759
760LogicalResult tosa::ConstOp::verify() {
761
762 auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
763 auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());
764
765 if (!attrType || !outputType) {
766 emitOpError("expected tensors for attr/result type");
767 return failure();
768 }
769
770 if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
771 outputType.getElementType())) {
772 if (getStorageElementTypeFromQuantized(result) == attrType.getElementType())
773 return success();
774 }
775
776 if (attrType.getElementType() != outputType.getElementType()) {
777 emitOpError("expected same attr/result element types");
778 return failure();
779 }
780
781 return success();
782}
783
784template <typename T>
785static LogicalResult verifyConvOpModes(T op) {
786 auto inputEType =
787 llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
788
789 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
790 inputEType = getStorageElementTypeFromQuantized(quantType);
791
792 auto accType = op.getAccType();
793 if (inputEType.isInteger(8) && !accType.isInteger(32))
794 return op.emitOpError("accumulator type for i8 tensor is not i32, got ")
795 << accType;
796
797 if (inputEType.isInteger(16) && !accType.isInteger(48))
798 return op.emitOpError("accumulator type for i16 tensor is not i48, got ")
799 << accType;
800
801 if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) &&
802 !(accType.isF16() || accType.isF32()))
803 return op.emitOpError("accumulator type for f8 tensor is not f16/f32, got ")
804 << accType;
805
806 if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
807 return op.emitOpError(
808 "accumulator type for f16 tensor is not f16/f32, got ")
809 << accType;
810
811 if (inputEType.isBF16() && !accType.isF32())
812 return op.emitOpError("accumulator type for bf16 tensor is not f32, got ")
813 << accType;
814
815 if (inputEType.isF32() && !accType.isF32())
816 return op.emitOpError("accumulator type for f32 tensor is not f32, got ")
817 << accType;
818
819 auto resultEType =
820 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
821
822 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
823 resultEType = getStorageElementTypeFromQuantized(quantType);
824
825 return success();
826}
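// Summary of the accumulator-type constraints checked above:
//   input element type     allowed acc_type
//   i8                     i32
//   i16                    i48
//   f8 (E5M2 / E4M3)       f16 or f32
//   f16                    f16 or f32
//   bf16                   f32
//   f32                    f32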
827
828//===----------------------------------------------------------------------===//
829// ERROR_IF functions.
830// ERROR_IF is a predicate that must set an error if the condition holds.
831//===----------------------------------------------------------------------===//
832
833template <typename T>
834static LogicalResult verifyConvOpErrorIf(T op) {
835 llvm::ArrayRef<int64_t> padding = op.getPad();
836 if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
837 return op.emitOpError("expect all padding values to be >= 0, got ")
838 << padding;
839
840 llvm::ArrayRef<int64_t> strides = op.getStride();
841 if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
842 return op.emitOpError("expect all stride values to be >= 1, got ")
843 << strides;
844
845 llvm::ArrayRef<int64_t> dilations = op.getDilation();
846 if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
847 return op.emitOpError("expect all dilation values to be >= 1, got ")
848 << dilations;
849
850 const RankedTensorType outputType =
851 llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
852 if (!outputType)
853 // Skip following checks if output is not ranked
854 return success();
855
856 const RankedTensorType inputType =
857 llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
858 const RankedTensorType weightType =
859 llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
860
861 if (inputType && weightType) {
862 // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
863 if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
864 if (failed(verifyConvOutputSize(
865 op, inputType.getDimSize(1), weightType.getDimSize(1),
866 outputType.getDimSize(1), padding[0], padding[1], strides[0],
867 dilations[0], "height", "y", "top", "bottom")))
868 return failure();
869
870 if (failed(verifyConvOutputSize(
871 op, inputType.getDimSize(2), weightType.getDimSize(2),
872 outputType.getDimSize(2), padding[2], padding[3], strides[1],
873 dilations[1], "width", "x", "left", "right")))
874 return failure();
875 }
876
877 // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
878 if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
879 if (failed(verifyConvOutputSize(
880 op, inputType.getDimSize(1), weightType.getDimSize(0),
881 outputType.getDimSize(1), padding[0], padding[1], strides[0],
882 dilations[0], "height", "y", "top", "bottom")))
883 return failure();
884
885 if (failed(verifyConvOutputSize(
886 op, inputType.getDimSize(2), weightType.getDimSize(1),
887 outputType.getDimSize(2), padding[2], padding[3], strides[1],
888 dilations[1], "width", "x", "left", "right")))
889 return failure();
890 }
891
892 // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
893 if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
894 if (failed(verifyConvOutputSize(
895 op, inputType.getDimSize(1), weightType.getDimSize(1),
896 outputType.getDimSize(1), padding[0], padding[1], strides[0],
897 dilations[0], "depth", "d", "front", "back")))
898 return failure();
899
900 if (failed(verifyConvOutputSize(
901 op, inputType.getDimSize(2), weightType.getDimSize(2),
902 outputType.getDimSize(2), padding[2], padding[3], strides[1],
903 dilations[1], "height", "y", "top", "bottom")))
904 return failure();
905
906 if (failed(verifyConvOutputSize(
907 op, inputType.getDimSize(3), weightType.getDimSize(3),
908 outputType.getDimSize(3), padding[4], padding[5], strides[2],
909 dilations[2], "width", "x", "left", "right")))
910 return failure();
911 }
912 }
913
914 const RankedTensorType biasType =
915 llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
916 if (!biasType)
917 // Skip following checks if bias is not ranked
918 return success();
919
920 const int64_t biasChannels = biasType.getDimSize(0);
921 const int64_t outputChannels =
922 outputType.getDimSize(outputType.getRank() - 1);
923 if (biasChannels == ShapedType::kDynamic ||
924 outputChannels == ShapedType::kDynamic)
925 // Skip following checks if biasChannels or outputChannels is dynamic dim
926 return success();
927
928 if (biasChannels != outputChannels && biasChannels != 1)
929 return op.emitOpError(
930 "bias channels expected to be equal to output channels (")
931 << outputChannels << ") or 1, got " << biasChannels;
932
933 return success();
934}
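// e.g. for an output with 64 channels, the bias may have either 64 channels
// (per-channel bias) or 1 channel (broadcast bias); any other static size
// fails the check above.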
935
936// Verify that the given two types have the same element type and shape.
937static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
938 StringRef name1, Type type2,
939 StringRef name2) {
940 auto shapeType1 = dyn_cast<ShapedType>(type1);
941 auto shapeType2 = dyn_cast<ShapedType>(type2);
942 if (!shapeType1 || !shapeType2)
943 return failure();
944
945 auto elemType1 = shapeType1.getElementType();
946 auto elemType2 = shapeType2.getElementType();
947 if (elemType1 != elemType2)
948 return op->emitOpError()
949 << "require same element type for " << name1 << " (" << elemType1
950 << ") and " << name2 << " (" << elemType2 << ")";
951
952 if (failed(verifyCompatibleShape(type1, type2)))
953 return op->emitOpError()
954 << "require same shapes for " << name1 << " (" << type1 << ") and "
955 << name2 << " (" << type2 << ")";
956
957 return success();
958}
959
960// Verify that the two tensor lists match in length, element type, and shape.
961static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, ValueRange list1,
962 StringRef name1,
963 ValueRange list2,
964 StringRef name2) {
965 if (list1.size() != list2.size())
966 return op->emitOpError()
967 << "require same number of values in " << name1 << " ("
968 << list1.size() << ") and " << name2 << " (" << list2.size() << ")";
969
970 for (auto [type1, type2] :
971 llvm::zip_equal(list1.getTypes(), list2.getTypes())) {
972 if (errorIfTypeOrShapeMismatch(op, type1, name1, type2, name2).failed())
973 return failure();
974 }
975
976 return success();
977}
978
979static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
980 ShapeAdaptor shapeAdaptor(type);
981 if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
982 return success();
983
984 return shapeAdaptor.getNumElements() == 1 ? success() : failure();
985}
986
987template <typename T>
988static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
989 Operation *symTableOp =
990 op->template getParentWithTrait<OpTrait::SymbolTable>();
991 if (!symTableOp)
992 // If the operation is not within the scope of a symbol table, we cannot
993 // verify it against its declaration.
994 return success();
995
996 SymbolTable symTable(symTableOp);
997 const auto varOp = symTable.lookup<tosa::VariableOp>(op.getName());
998
999 // Verify prior declaration
1000 if (!varOp)
1001 return op->emitOpError("'")
1002 << op.getName() << "' has not been declared by 'tosa.variable'";
1003
1004 // Verify type and shape
1005 auto variableType = getVariableType(varOp);
1006 if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
1007 "the input tensor")
1008 .failed())
1009 return failure();
1010 return success();
1011}
1012
1013// Verify that aType and bType have the same element type.
1014template <typename T>
1015static LogicalResult verifySameElementTypes(T op, Type aType, Type bType,
1016 StringRef aName = "input",
1017 StringRef bName = "output") {
1018 auto aTType = llvm::dyn_cast<TensorType>(aType);
1019 auto bTType = llvm::dyn_cast<TensorType>(bType);
1020 if (!aTType) {
1021 op.emitOpError("expect shaped tensor for ") << aName << ", got " << aType;
1022 return failure();
1023 }
1024 if (!bTType) {
1025 op.emitOpError("expect shaped tensor for ") << bName << ", got " << bType;
1026 return failure();
1027 }
1028 auto aElementType = aTType.getElementType();
1029 auto bElementType = bTType.getElementType();
1030 auto aQuantType =
1031 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType);
1032 auto bQuantType =
1033 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType);
1034 if ((aElementType.isIntOrIndexOrFloat() || aQuantType) &&
1035 (bElementType.isIntOrIndexOrFloat() || bQuantType) &&
1036 aElementType != bElementType) {
1037 // only check if both element types are int/index/float/UniformQuantized
1038 // eg, not sure how to check quant::QuantizedType
1039 // this happens in test_conv2d_q_grouped_convolution in
1040 // tfl-to-tosa-pipeline.mlir
1041 op.emitOpError("expect ")
1042 << aName << " and " << bName << " to have same element type, got "
1043 << aElementType << " and " << bElementType;
1044 return failure();
1045 }
1046 return success();
1047}
1048
1049LogicalResult tosa::ArgMaxOp::verify() {
1050 const ShapedType resultType = llvm::cast<ShapedType>(getType());
1051
1052 // Ensure output is of 32-bit integer
1053 if (const auto resultETy = resultType.getElementType();
1054 !resultETy.isIntOrIndex())
1055 return emitOpError("result tensor is not of integer type");
1056
1057 const auto inputType = llvm::cast<ShapedType>(getInput().getType());
1058 if (!inputType.hasRank())
1059 return success();
1060
1061 // Ensure axis is within the tensor rank
1062 const int64_t axis = getAxisAttr().getInt();
1063 if (((axis < 0) || axis >= inputType.getRank()))
1064 return emitOpError("specified axis is outside the rank of the tensor");
1065
1066 if (!resultType.hasRank())
1067 return success();
1068
1069 const ArrayRef<int64_t> inputShape = inputType.getShape();
1070 const ArrayRef<int64_t> outputShape = resultType.getShape();
1071 llvm::SmallVector<int64_t> expectedOutputShape(inputShape);
1072 expectedOutputShape.erase(expectedOutputShape.begin() + axis);
1073 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
1074 return emitOpError("expected output shape '")
1075 << expectedOutputShape << "', got '" << outputShape << "'";
1076
1077 return success();
1078}
1079
1080template <typename T>
1081static LogicalResult verifyPoolingOp(T op) {
1082 const llvm::ArrayRef<int64_t> kernel = op.getKernel();
1083 if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
1084 return op.emitOpError("expect all kernel values to be >= 1, got ")
1085 << kernel;
1086
1087 const llvm::ArrayRef<int64_t> strides = op.getStride();
1088 if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
1089 return op.emitOpError("expect all stride values to be >= 1, got ")
1090 << strides;
1091
1092 const llvm::ArrayRef<int64_t> padding = op.getPad();
1093 if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
1094 return op.emitOpError("expect all padding values to be >= 0, got ")
1095 << padding;
1096
1097 // Padding must be less than kernel size to avoid a divide-by-zero
1098 const int64_t kernelX = kernel[1];
1099 const int64_t padLeft = padding[2];
1100 const int64_t padRight = padding[3];
1101 if (padRight >= kernelX || padLeft >= kernelX)
1102 return op.emitOpError("expected left/right padding to be less than the "
1103 "width of the kernel, got pad_left=")
1104 << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;
1105
1106 const int64_t kernelY = kernel[0];
1107 const int64_t padTop = padding[0];
1108 const int64_t padBottom = padding[1];
1109 if (padTop >= kernelY || padBottom >= kernelY)
1110 return op.emitOpError("expected top/bottom padding to be less than the "
1111 "height of the kernel, got pad_top=")
1112 << padTop << ", pad_bottom=" << padBottom
1113 << ", kernel_y=" << kernelY;
1114
1115 const auto inputType =
1116 llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
1117 const auto outputType =
1118 llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
1119 if (!inputType || !outputType)
1120 return success();
1121
1122 const auto verifyOutputSize =
1123 [&op](const int64_t inputSize, const int64_t outputSize,
1124 const int64_t kernelSize, const int64_t strideSize,
1125 const int64_t padBefore, const int64_t padAfter,
1126 const llvm::StringRef dimName, const llvm::StringRef dimAxis,
1127 const llvm::StringRef padBeforeName,
1128 const llvm::StringRef padAfterName) -> LogicalResult {
1129 if (ShapedType::isDynamic(inputSize))
1130 return success();
1131
1132 const std::optional<int64_t> calculatedOutSizeMinusOne =
1133 idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
1134 if (!calculatedOutSizeMinusOne.has_value())
1135 return op.emitOpError("expected input_")
1136 << dimName << " + pad_" << padBeforeName << " + pad_"
1137 << padAfterName << " - kernel_" << dimAxis
1138 << " to be wholly divisible by stride_" << dimAxis << ", got ("
1139 << inputSize << " + " << padBefore << " + " << padAfter << " - "
1140 << kernelSize << ") / " << strideSize;
1141
1142 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
1143 if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
1144 return op.emitOpError("calculated output ")
1145 << dimName << " did not match expected: " << "calculated="
1146 << calculatedOutSize << ", expected=" << outputSize;
1147
1148 return success();
1149 };
1150
1151 if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
1152 kernel[0], strides[0], padding[0], padding[1],
1153 "height", "y", "top", "bottom")))
1154 return failure();
1155
1156 if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
1157 kernel[1], strides[1], padding[2], padding[3],
1158 "width", "x", "left", "right")))
1159 return failure();
1160
1161 return success();
1162}
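// Worked example of the pooling output-size rule above: with input height 8,
// pad_top = pad_bottom = 0, kernel_y = 2 and stride_y = 2, the output height
// must be (8 + 0 + 0 - 2) / 2 + 1 = 4.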
1163
1164LogicalResult tosa::AvgPool2dOp::verify() {
1165 if (failed(verifyPoolingOp(*this)))
1166 return failure();
1167
1168 const Type inputETy = getStorageElementTypeOrSelf(getInput().getType());
1169 const Type resultETy = getStorageElementTypeOrSelf(getOutput().getType());
1170 const Type inputZpETy = getStorageElementTypeOrSelf(getInputZp().getType());
1171 const Type outputZpETy = getStorageElementTypeOrSelf(getOutputZp().getType());
1172
1173 auto accType = getAccType();
1174 if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
1175 return emitOpError("accumulator type for integer tensor is not i32");
1176
1177 if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
1178 return emitOpError("accumulator type for f16 tensor is not f16/f32");
1179
1180 if (inputETy.isBF16() && !accType.isF32())
1181 return emitOpError("accumulator type for bf16 tensor is not f32");
1182
1183 if (inputETy.isF32() && !accType.isF32())
1184 return emitOpError("accumulator type for f32 tensor is not f32");
1185
1186 if (inputETy != inputZpETy)
1187 return emitOpError("expect both input and its zero point are the same "
1188 "element type, got ")
1189 << inputETy << " and " << inputZpETy;
1190
1191 if (resultETy != outputZpETy)
1192 return emitOpError("expect both output and its zero point are the same "
1193 "element type, got ")
1194 << resultETy << " and " << outputZpETy;
1195
1196 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
1197 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
1198 return failure();
1199
1200 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1201 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
1202 return failure();
1203
1204 return success();
1205}
1206
1207LogicalResult tosa::ClampOp::verify() {
1208 mlir::Type inputETy =
1209 llvm::cast<ShapedType>(getInput().getType()).getElementType();
1210 if (auto quantType =
1211 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
1212 inputETy = getStorageElementTypeFromQuantized(quantType);
1213 }
1214 mlir::Type outputETy =
1215 llvm::cast<ShapedType>(getOutput().getType()).getElementType();
1216 if (auto quantType =
1217 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
1218 outputETy = getStorageElementTypeFromQuantized(quantType);
1219 }
1220 if (inputETy != outputETy)
1221 return emitOpError("input/output element types are incompatible.");
1222
1223 auto maxValAttr = getMaxValAttr();
1224 auto minValAttr = getMinValAttr();
1225
1226 unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
1227
1228 if (inputETy.isInteger(dataTypeBitWidth)) {
1229 // if input datatype is integer, check that the min_val/max_val attributes
1230 // are integer attributes, and that their type is the same as the input's
1231 // datatype
1232 auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
1233 auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
1234 if (!intMaxValAttr || !intMinValAttr ||
1235 (intMaxValAttr.getType() != intMinValAttr.getType()) ||
1236 (intMaxValAttr.getType() != inputETy))
1237 return emitOpError("min/max attributes types are incompatible with "
1238 "input/output element types.");
1239
1240 const bool isUnsigned = inputETy.isUnsignedInteger();
1241 const bool isBoolean = inputETy.isInteger(1);
1242 const APInt minVal = intMinValAttr.getValue();
1243 const APInt maxVal = intMaxValAttr.getValue();
1244 if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
1245 return emitOpError("expected min_val <= max_val, got min_val=")
1246 << minValAttr << ", max_val=" << maxValAttr;
1247 } else {
1248 // otherwise, input datatype is float, check that the min_val/max_val
1249 // attributes share the same type and that their type is the same as the
1250 // input's datatype
1251 auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
1252 auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
1253 if (!floatMaxValAttr || !floatMinValAttr ||
1254 (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
1255 (floatMaxValAttr.getType() != inputETy))
1256 return emitOpError("min/max attributes types are incompatible with "
1257 "input/output element types.");
1258
1259 const APFloat minVal = floatMinValAttr.getValue();
1260 const APFloat maxVal = floatMaxValAttr.getValue();
1261 if (minVal.isNaN() || maxVal.isNaN())
1262 return emitOpError("min/max attributes should not be 'NaN', got min_val=")
1263 << minValAttr << ", max_val=" << maxValAttr;
1264
1265 if (maxVal < minVal)
1266 return emitOpError("expected min_val <= max_val, got min_val=")
1267 << minValAttr << ", max_val=" << maxValAttr;
1268 }
1269
1270 return success();
1271}
1272
1273//===----------------------------------------------------------------------===//
1274// TOSA Operator Quantization Builders.
1275//===----------------------------------------------------------------------===//
1276
1277/// This builder is called on all convolution operators except TransposeConv,
1278/// which has specialized output shape semantics. The builder also defines the
1279/// bitwidth of the output given the bit width of the input & weight content.
1280static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1281 Type outputType, Value input, Value weight,
1282 Value bias, DenseI64ArrayAttr pad,
1283 DenseI64ArrayAttr stride,
1284 DenseI64ArrayAttr dilation,
1285 TypeAttr accType) {
1286 auto zps = createZPsAsConst(builder, input, weight);
1287 result.addOperands({input, weight, bias, zps.first, zps.second});
1288 result.addAttribute("pad", pad);
1289 result.addAttribute("stride", stride);
1290 result.addAttribute("dilation", dilation);
1291 result.addAttribute("acc_type", accType);
1292 Type finalOutputType = outputType;
1293 auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1294 if (quantAttr) {
1295 finalOutputType =
1296 buildConvOpResultTypeInfo(builder, outputType, input, weight);
1297 }
1298 result.addTypes(finalOutputType);
1299}
1300
1301/// Handles tosa.transpose_conv2d which has outpad and output shape
1302/// attributes.
1303static void
1304buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1305 Type outputType, Value input, Value weight,
1306 Value bias, DenseI64ArrayAttr outpad,
1307 DenseI64ArrayAttr stride, TypeAttr accType) {
1308 auto zps = createZPsAsConst(builder, input, weight);
1309 result.addOperands({input, weight, bias, zps.first, zps.second});
1310 result.addAttribute("out_pad", outpad);
1311 result.addAttribute("stride", stride);
1312 result.addAttribute("acc_type", accType);
1313 Type finalOutputType = outputType;
1314 auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1315 if (quantAttr) {
1316 finalOutputType =
1317 buildConvOpResultTypeInfo(builder, outputType, input, weight);
1318 }
1319 result.addTypes(finalOutputType);
1320}
1321
1322/// The tosa.matmul op is also intended to be generated where a fully_connected
1323/// op must be constructed where the weight is not a constant. In this case,
1324/// the fully_connected op must be expressed using matmul.
1325/// TODO: Add link to the legalization document explaining this.
1326static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
1327 OperationState &result, Type outputType,
1328 Value a, Value b) {
1329 auto zps = createZPsAsConst(builder, a, b);
1330 result.addOperands({a, b, zps.first, zps.second});
1331
1332 Type finalOutputType{outputType};
1333 if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
1334 auto eType = getStorageElementTypeOrSelf(a.getType());
1335 auto inputBits = eType.getIntOrFloatBitWidth();
1336
1337 auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
1338 assert(outputShapedType && "Output must be a shaped type");
1339
1340 IntegerType accElementType;
1341 if (inputBits == 16)
1342 accElementType = builder.getIntegerType(48);
1343 else
1344 accElementType = builder.getI32Type();
1345
1346 finalOutputType = outputShapedType.clone(accElementType);
1347 }
1348 result.addTypes(finalOutputType);
1349}
1350
1351/// Both the tosa.avg_pool2d and unary ops use the same
1352/// UnaryOpQuantizationAttr but avg_pool operator has its own builder as it
1353/// has additional parameters not part of the unary ops.
1354static void
1355buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1356 Type outputType, Value input,
1357 DenseArrayAttr kernel, DenseArrayAttr stride,
1358 DenseArrayAttr pad, TypeAttr accType) {
1359 const Location loc{result.location};
1360 int64_t inputZp{0};
1361 int64_t outputZp{0};
1362
1363 if (auto quantAttr =
1364 buildUnaryOpQuantizationAttr(builder, input, outputType)) {
1365 inputZp = quantAttr.getInputZp();
1366 outputZp = quantAttr.getOutputZp();
1367 }
1368 const std::optional<Value> inputZpOp =
1369 createZeroPointTensor(builder, loc, input.getType(), inputZp);
1370 if (!inputZpOp) {
1371 (void)emitError(
1372 loc,
1373 "Failed to create input zero point tensor for quantized AVG_POOL2D op");
1374 }
1375 const std::optional<Value> outputZpOp =
1376 createZeroPointTensor(builder, loc, outputType, outputZp);
1377 if (!outputZpOp) {
1378 (void)emitError(loc, "Failed to create output zero point tensor for "
1379 "quantized AVG_POOL2D op");
1380 }
1381
1382 if (inputZpOp && outputZpOp) {
1383 result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
1384 } else {
1385 // failed to create one or more zero points above: just add input as
1386 // operands; this will trigger an error when building the op because of
1387 // the missing zero points
1388 result.addOperands({input});
1389 }
1390 result.addAttribute("kernel", kernel);
1391 result.addAttribute("stride", stride);
1392 result.addAttribute("pad", pad);
1393 result.addAttribute("acc_type", accType);
1394 result.types.push_back(outputType);
1395}
1396
1397/// This builder is called on single-parameter negate operator
1398/// to construct input and output zero points based on their
1399/// types.
1400static void buildNegateOpWithQuantInfo(OpBuilder &builder,
1401 OperationState &result, Type outputType,
1402 Value input) {
1403 const Location loc{result.location};
1404 int64_t input1Zp{0};
1405 int64_t outputZp{0};
1406 auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
1407 if (quantAttr) {
1408 input1Zp = quantAttr.getInputZp();
1409 outputZp = quantAttr.getOutputZp();
1410 }
1411 const std::optional<Value> input1ZpOp =
1412 createZeroPointTensor(builder, loc, input.getType(), input1Zp);
1413 if (!input1ZpOp) {
1414 (void)emitError(
1415 loc, "Failed to create input1 zero point for quantized NEGATE op");
1416 }
1417
1418 const std::optional<Value> outputZpOp =
1419 createZeroPointTensor(builder, loc, input.getType(), outputZp);
1420 if (!outputZpOp) {
1421 (void)emitError(
1422 loc, "Failed to create output zero point for quantized NEGATE op");
1423 }
1424
1425 if (input1ZpOp && outputZpOp) {
1426 result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
1427 } else {
1428 // failed to create one or more zero points above: just add input as
1429 // operands. This will trigger error in building the op because of
1430 // missing zero points
1431 result.addOperands({input});
1432 }
1433
1434 result.types.push_back(outputType);
1435}
1436
1437/// This builder is called on TOSA pad operator that needs to create its own
1438/// OptionalAttr quantization_attr parameter to scale the padding values
1439/// correctly. No pad_const is interpreted as zero-padding.
1440static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1441 Type outputType, Value input,
1442 Value paddings) {
1443 const Location loc{result.location};
1444 int32_t zp{0};
1445 const auto quantAttr = buildPadOpQuantizationAttr(builder, input);
1446 if (quantAttr) {
1447 zp = static_cast<int32_t>(quantAttr.getInputZp());
1448 }
1449 const auto padConstOp{createPadConstTensor(builder, loc, input, zp)};
1450 result.addOperands({input, paddings, padConstOp});
1451 result.types.push_back(outputType);
1452}
1453
1454static void buildVariableOp(OpBuilder &builder, OperationState &result,
1455 StringRef name, Type variableType,
1456 Attribute initialValue) {
1457 const Location loc{result.location};
1458 auto nameAttr = builder.getStringAttr(name);
1459
1460 auto shapedType = dyn_cast<ShapedType>(variableType);
1461 if (!shapedType) {
1462 (void)emitError(loc, "variable type must be a shaped type");
1463 return;
1464 }
1465 if (!shapedType.hasRank()) {
1466 (void)emitError(loc, "variable type must be a ranked type");
1467 return;
1468 }
1469
1470 auto elementType = shapedType.getElementType();
1471 auto elementTypeAttr = TypeAttr::get(elementType);
1472 ArrayRef<int64_t> shape = shapedType.getShape();
1473 auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
1474
1475 result.addAttribute("sym_name", nameAttr);
1476 result.addAttribute("var_shape", varShapeAttr);
1477 result.addAttribute("type", elementTypeAttr);
1478 result.addAttribute("initial_value", initialValue);
1479}
1480
1481//===----------------------------------------------------------------------===//
1482// TOSA Operator Return Type Inference.
1483//===----------------------------------------------------------------------===//
1484static FailureOr<int64_t> resolveBroadcastDim(const int64_t dim1,
1485 const int64_t dim2) {
1486 if (dim1 == 1)
1487 return dim2;
1488 if (dim2 == 1)
1489 return dim1;
1490
1491 if (ShapedType::isStatic(dim1) && ShapedType::isStatic(dim2) && dim1 != dim2)
1492 return failure();
1493
1494 // Prefer static dimension over dynamic
1495 return ShapedType::isDynamic(dim1) ? dim2 : dim1;
1496}
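// e.g. resolveBroadcastDim(1, 5) == 5, resolveBroadcastDim(kDynamic, 4) == 4,
// and resolveBroadcastDim(3, 5) fails because both dims are static and differ.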
1497
1498static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
1499 SmallVector<int64_t> &outShape) {
1500 int64_t outRank = 0;
1501 for (int i = 0, e = operands.size(); i != e; ++i) {
1502 auto shape = operands.getShape(i);
1503 if (!shape.hasRank()) {
1504 // TODO(jennik): Update function to have better case handling for
1505 // invalid operands and for ranked tensors.
1506 return failure();
1507 }
1508 outRank = std::max<int64_t>(outRank, shape.getRank());
1509 }
1510
1511 outShape.resize(outRank, 1);
1512
1513 for (int i = 0, e = operands.size(); i != e; ++i) {
1514 auto shape = operands.getShape(i);
1515 auto rankDiff = outShape.size() - shape.getRank();
1516
1517 for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
1518 auto dim1 = outShape[i + rankDiff];
1519 auto dim2 = shape.getDimSize(i);
1520
1521 const FailureOr<int64_t> maybeResolvedDim =
1522 resolveBroadcastDim(dim1, dim2);
1523 if (failed(maybeResolvedDim))
1524 return failure();
1525 const int64_t resolvedDim = *maybeResolvedDim;
1526 outShape[i + rankDiff] = resolvedDim;
1527 }
1528 }
1529
1530 return success();
1531}
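// e.g. broadcasting operand shapes {2, 1, 3} and {4, 3} resolves to the output
// shape {2, 4, 3}.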
1532
1533LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
1534 MLIRContext *context, ::std::optional<Location> location,
1535 ArgMaxOp::Adaptor adaptor,
1536 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1537 ShapeAdaptor inputShape(adaptor.getInput().getType());
1538 IntegerAttr axis = adaptor.getProperties().axis;
1539 int32_t axisVal = axis.getValue().getSExtValue();
1540
1541 if (!inputShape.hasRank()) {
1542 inferredReturnShapes.push_back(ShapedTypeComponents());
1543 return success();
1544 }
1545
1546 SmallVector<int64_t> outShape;
1547 outShape.reserve(inputShape.getRank() - 1);
1548 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1549 if (i == axisVal)
1550 continue;
1551 outShape.push_back(inputShape.getDimSize(i));
1552 }
1553
1554 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1555 return success();
1556}
1557
1558LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
1559 MLIRContext *context, ::std::optional<Location> location,
1560 RFFT2dOp::Adaptor adaptor,
1561 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1562 ShapeAdaptor inputShape(adaptor.getInputReal().getType());
1563
1564 if (!inputShape.hasRank())
1565 return failure();
1566
1567 llvm::SmallVector<int64_t> outputShape;
1568 outputShape.resize(3, ShapedType::kDynamic);
1569 outputShape[0] = inputShape.getDimSize(0);
1570 outputShape[1] = inputShape.getDimSize(1);
1571 int64_t inWidth = inputShape.getDimSize(2);
1572
1573 // Note that we can support this calculation symbolically
1574 // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
1575 if (inWidth != ShapedType::kDynamic)
1576 outputShape[2] = inWidth / 2 + 1;
1577
1578 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1579 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1580
1581 return success();
1582}
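// For example, a static input_real of shape [8, 16, 32] gives both results the
// inferred shape [8, 16, 32 / 2 + 1] = [8, 16, 17]; a dynamic input width
// leaves the last output dimension dynamic.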
1583
1584static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
1585 const llvm::StringRef dimName) {
1586 const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
1587 if (!isPowerOfTwo)
1588 return op->emitOpError("expected ")
1589 << dimName << " to be a power of two, got " << dimSize;
1590
1591 return success();
1592}
1593
1594LogicalResult tosa::RFFT2dOp::verify() {
1595 const auto outputTypes = getResultTypes();
1596 if (failed(verifyCompatibleShapes(outputTypes)))
1597 return emitOpError("expected output shapes to match, got ") << outputTypes;
1598
1599 const auto inputType =
1600 llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1601 if (!inputType)
1602 return success();
1603
1604 const int64_t height = inputType.getDimSize(1);
1605 if (ShapedType::isStatic(height) &&
1606 failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1607 return failure();
1608
1609 const int64_t width = inputType.getDimSize(2);
1610 if (ShapedType::isStatic(width) &&
1611 failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1612 return failure();
1613
1614 const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
1615 if (!outputType)
1616 return success();
1617
1618 // Batch and height input/output dimensions should match
1619 if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
1620 outputType.getShape().drop_back())))
1621 return emitOpError("expected batch and height dimensions of input/output "
1622 "to match, got input=")
1623 << inputType << " output=" << outputType;
1624
1625 // Output width dimension expected to be input_width / 2 + 1
1626 const int64_t outputWidth = outputType.getDimSize(2);
1627 if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
1628 (outputWidth != (width / 2) + 1))
1629 return emitOpError(
1630 "expected output width to be equal to input_width / 2 + 1, got ")
1631 << outputWidth;
1632
1633 return success();
1634}
1635
1636LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1637 MLIRContext *context, ::std::optional<Location> location,
1638 FFT2dOp::Adaptor adaptor,
1639 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1640 inferredReturnShapes.push_back(
1641 ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
1642 inferredReturnShapes.push_back(
1643 ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
1644 return success();
1645}
1646
1647LogicalResult tosa::FFT2dOp::verify() {
1648 const auto inputRealType =
1649 llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1650 const auto inputImagType =
1651 llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
1652 if (!inputRealType || !inputImagType)
1653 return success();
1654
1655 const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
1656 return ShapedType::isDynamic(a) ? b : a;
1657 };
1658
1659 const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
1660 inputImagType.getDimSize(1));
1661 if (ShapedType::isStatic(height) &&
1662 failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1663 return failure();
1664
1665 const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
1666 inputImagType.getDimSize(2));
1667 if (ShapedType::isStatic(width) &&
1668 failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1669 return failure();
1670
1671 return success();
1672}
1673
1674LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
1675 MLIRContext *context, ::std::optional<Location> location,
1676 ConcatOp::Adaptor adaptor,
1677 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1678 // Infer all dimension sizes by reducing based on inputs.
1679 const Properties &prop = adaptor.getProperties();
1680 int32_t axis = prop.axis.getValue().getSExtValue();
1681 llvm::SmallVector<int64_t> outputShape;
1682 bool hasRankedInput = false;
1683 for (auto operand : adaptor.getOperands()) {
1684 ShapeAdaptor operandShape(operand.getType());
1685 if (!operandShape.hasRank())
1686 continue;
1687
1688 // Copy the Operand's rank.
1689 if (!hasRankedInput)
1690 outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
1691
1692 // Copy shapes until the dim is non-dynamic.
1693 for (int i = 0, s = operandShape.getRank(); i < s; i++) {
1694 if (i == axis || operandShape.isDynamicDim(i))
1695 continue;
1696 if (outputShape[i] == ShapedType::kDynamic)
1697 outputShape[i] = operandShape.getDimSize(i);
1698 if (outputShape[i] != operandShape.getDimSize(i))
1699 return emitOptionalError(location,
1700 "Cannot concat tensors with different sizes"
1701 " on the non-axis dimension ",
1702 i);
1703 }
1704
1705 hasRankedInput = true;
1706 }
1707
1708 if (adaptor.getInput1().empty())
1709 return failure();
1710
1711 Type inputType =
1712 llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
1713 if (!hasRankedInput) {
1714 inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
1715 return success();
1716 }
1717
1718 // Determine the dimension size along the concatenation axis.
1719 int64_t concatDimSize = 0;
1720 for (auto operand : adaptor.getOperands()) {
1721 ShapeAdaptor operandShape(operand.getType());
1722
1723 // We need to know the length of the concatenation axis of all inputs to
1724 // determine the dimension size of the output shape.
1725 if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
1726 concatDimSize = ShapedType::kDynamic;
1727 break;
1728 }
1729
1730 concatDimSize += operandShape.getDimSize(axis);
1731 }
1732
1733 outputShape[axis] = concatDimSize;
1734
1735 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
1736 return success();
1737}
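// For example, concatenating operands of shape [2, 3, 4] and [2, 5, 4] along
// axis 1 infers [2, 8, 4]; if any operand's axis dimension is dynamic (e.g.
// [2, ?, 4]), the inferred axis dimension is dynamic as well.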
1738
1739LogicalResult tosa::ConcatOp::verify() {
1740 // check that each input has same element type as output
1741 auto outType = getOutput().getType();
1742 const Operation::operand_range inputList = getInput1();
1743
1744 // Check there is at least one input
1745 if (inputList.empty())
1746 return emitOpError("expect at least one input");
1747
1748 if (!llvm::all_of(inputList, [&](auto input) {
1749 return succeeded(verifySameElementTypes(
1750 *this, /* inType = */ input.getType(), outType));
1751 })) {
1752 return failure();
1753 }
1754
1755 const int32_t axis = getAxis();
1756 ShapeAdaptor firstRankedInputShape = nullptr;
1757 for (const auto &input : inputList) {
1758 const Type inputType = input.getType();
1759 ShapeAdaptor currShape(inputType);
1760 if (currShape.hasRank()) {
1761 firstRankedInputShape = currShape;
1762 // Check axis is in expected range
1763 if (axis < 0 || axis >= firstRankedInputShape.getRank())
1764 return emitOpError("expect axis to be within range 0 < axis < "
1765 "rank(input1[firstRankedTensorIdx]), got ")
1766 << axis;
1767 break;
1768 }
1769 }
1770
1771 const auto allOperandsHasRank = [](const Value input) {
1772 return ShapeAdaptor(input.getType()).hasRank();
1773 };
1774 if (llvm::all_of(inputList, allOperandsHasRank)) {
1775 const int64_t firstInputRank = firstRankedInputShape.getRank();
1776
1777 for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
1778 const ShapeAdaptor inputShape(input.getType());
1779 const int64_t inputRank = inputShape.getRank();
1780 const size_t operandNum = index + 1;
1781
1782 // Check that each operand has the same rank
1783 if (inputRank != firstInputRank)
1784 return emitOpError(
1785 "expect all operands to have the same rank, but got ")
1786 << firstInputRank << " vs " << inputRank << " on operands 0 and "
1787 << operandNum;
1788
1789 // Check non-axis dims match
1790 for (int i = 0; i < inputRank; i++) {
1791 const int64_t inputDim = inputShape.getDimSize(i);
1792 const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
1793 if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
1794 inputShape.isDynamicDim(i))
1795 continue;
1796 if (inputDim != firstInputDim)
1797 return emitOpError("expect all operand shapes to have the same sizes "
1798 "on non-axis dimensions, but got ")
1799 << inputDim << " vs " << firstInputDim << " at index " << i
1800 << " on operands 0 and " << operandNum;
1801 }
1802 }
1803
1804 const ShapeAdaptor outputShape(outType);
1805 if (outputShape.hasRank() && outputShape.getRank() != firstInputRank)
1806 return emitOpError("expect output rank to match inputs rank, got ")
1807 << outputShape.getRank() << " vs " << firstInputRank;
1808
1809 // ERROR_IF(axis_sum != shape[axis]);
1810 int64_t axisSum = 0;
1811 for (const auto &input : inputList) {
1812 const ShapeAdaptor inputShape(input.getType());
1813 if (inputShape.isDynamicDim(axis)) {
1814 // make axisSum negative to indicate invalid value
1815 axisSum = -1;
1816 break;
1817 }
1818 axisSum += inputShape.getDimSize(axis);
1819 }
1820
1821 if (axisSum >= 0 && outputShape.hasRank() &&
1822 !outputShape.isDynamicDim(axis) &&
1823 axisSum != outputShape.getDimSize(axis))
1824 return emitOpError("requires sum of axis dimensions of input1 "
1825 "equal to output axis dimension, got ")
1826 << axisSum << " and " << outputShape.getDimSize(axis);
1827 }
1828
1829 return success();
1830}
1831
1832LogicalResult tosa::EqualOp::inferReturnTypeComponents(
1833 MLIRContext *context, ::std::optional<Location> location,
1834 ValueShapeRange operands, DictionaryAttr attributes,
1835 OpaqueProperties properties, RegionRange regions,
1836 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1837 auto elementType = IntegerType::get(context, /*width=*/1);
1838
1839 llvm::SmallVector<int64_t> outShape;
1840 if (resolveBroadcastShape(operands, outShape).failed()) {
1841 inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
1842 return success();
1843 }
1844
1845 inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
1846 return success();
1847}
1848
1849bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1850 if (l.size() != r.size() || l.size() != 1)
1851 return false;
1852 return succeeded(verifyCompatibleShape(l[0], r[0]));
1853}
1854
1855LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
1856 MLIRContext *context, ::std::optional<Location> location,
1857 MatMulOp::Adaptor adaptor,
1858 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1859 ShapeAdaptor lhsShape(adaptor.getA().getType());
1860 ShapeAdaptor rhsShape(adaptor.getB().getType());
1861
1862 // All shapes are dynamic.
1863 SmallVector<int64_t> outShape;
1864 outShape.resize(3, ShapedType::kDynamic);
1865
1866 if (lhsShape.hasRank()) {
1867 outShape[0] = lhsShape.getDimSize(0);
1868 outShape[1] = lhsShape.getDimSize(1);
1869 }
1870
1871 if (rhsShape.hasRank()) {
1872 outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
1873 : outShape[0];
1874 outShape[2] = rhsShape.getDimSize(2);
1875 }
1876
1877 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1878 return success();
1879}
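// For example, operands of shape [N, H, C] = [2, 3, 4] and [N, C, W] =
// [2, 4, 5] infer an output of shape [N, H, W] = [2, 3, 5]; dimensions that
// cannot be determined from the ranked operands stay dynamic.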
1880
1881LogicalResult MatMulOp::verify() {
1882 auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
1883 auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
1884
1885 // Must be shaped tensor types
1886 if (!aType)
1887 return emitOpError("expect a shaped tensor for input a, got ")
1888 << getA().getType();
1889
1890 if (!bType)
1891 return emitOpError("expect a shaped tensor for input b, got ")
1892 << getB().getType();
1893
1894 auto aElementType = aType.getElementType();
1895 auto bElementType = bType.getElementType();
1896
1897 auto aQuantizedEType =
1898 llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1899 auto bQuantizedEType =
1900 llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1901
1902 if (aQuantizedEType || bQuantizedEType) {
1903 if (!aQuantizedEType || !bQuantizedEType) {
1904 return emitOpError("expect operands to be both quantized or both not "
1905 "quantized, got ")
1906 << aElementType << " and " << bElementType;
1907 }
1908 // both a and b have quantized element types
1909 auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
1910 auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
1911 if (aQuantWidth != bQuantWidth) {
1912 return emitOpError("expect quantized operands to have same widths, got ")
1913 << aQuantWidth << " and " << bQuantWidth;
1914 }
1915 }
1916
1917 // check a_zp and b_zp
1918 auto aEType = getStorageElementTypeOrSelf(aType);
1919 auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType());
1920 if (aEType != aZpEType) {
1921 return emitOpError("expect input a and a_zp have the same "
1922 "element type, got ")
1923 << aEType << " and " << aZpEType;
1924 }
1925
1926 auto bEType = getStorageElementTypeOrSelf(bType);
1927 auto bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
1928 if (bEType != bZpEType) {
1929 return emitOpError("expect input b and b_zp have the same "
1930 "element type, got ")
1931 << bEType << " and " << bZpEType;
1932 }
1933
1934 FailureOr<int64_t> maybeAZp = getAZeroPoint();
1935 if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
1936 return failure();
1937
1938 FailureOr<int64_t> maybeBZp = getBZeroPoint();
1939 if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
1940 return failure();
1941
1942 return success();
1943}
1944
1945LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents(
1946 MLIRContext *context, ::std::optional<Location> location,
1947 MatmulTBlockScaledOp::Adaptor adaptor,
1948 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1949 SmallVector<int64_t, 3> outShape(3, ShapedType::kDynamic);
1950
1951 const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType());
1952 if (aDataShape.hasRank()) {
1953 outShape[0] = aDataShape.getDimSize(0);
1954 outShape[1] = aDataShape.getDimSize(1);
1955 }
1956
1957 const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType());
1958 if (aScaleShape.hasRank()) {
1959 outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0)
1960 : outShape[0];
1961 outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1)
1962 : outShape[1];
1963 }
1964
1965 // If B batch size is 1, it is broadcast across A's batch size
1966 const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType());
1967 if (bDataShape.hasRank()) {
1968 const int64_t bDataBatchSize = bDataShape.getDimSize(0);
1969 if (bDataBatchSize != 1)
1970 outShape[0] =
1971 ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0];
1972 outShape[2] = bDataShape.getDimSize(1);
1973 }
1974
1975 const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType());
1976 if (bScaleShape.hasRank()) {
1977 const int64_t bScaleBatchSize = bScaleShape.getDimSize(0);
1978 if (bScaleBatchSize != 1)
1979 outShape[0] =
1980 ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0];
1981 outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1)
1982 : outShape[2];
1983 }
1984
1985 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1986 return success();
1987}
1988
1989LogicalResult MatmulTBlockScaledOp::verify() {
1990 // Verify same input data types
1991 const Type aDataType = getAData().getType();
1992 const Type bDataType = getBData().getType();
1993 if (failed(verifySameElementTypes(*this, aDataType, bDataType, "A_data",
1994 "B_data")))
1995 return failure();
1996
1997 // Verify input shape compatibility
1998 int64_t N = ShapedType::kDynamic;
1999 int64_t D = ShapedType::kDynamic;
2000 int64_t H = ShapedType::kDynamic;
2001 int64_t W = ShapedType::kDynamic;
2002 int64_t C = ShapedType::kDynamic;
2003 int64_t multiplesOfC = ShapedType::kDynamic;
2004
2005 const ShapeAdaptor aDataShape = ShapeAdaptor(aDataType);
2006 if (aDataShape.hasRank()) {
2007 N = aDataShape.getDimSize(0);
2008 H = aDataShape.getDimSize(1);
2009 C = aDataShape.getDimSize(2);
2010 }
2011
2012 const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType());
2013 if (aScaleShape.hasRank()) {
2014 if (failed(tryUpdateDimOrFailure(*this, N, aScaleShape.getDimSize(0),
2015 "a_scale", "batch")) ||
2016 failed(tryUpdateDimOrFailure(*this, H, aScaleShape.getDimSize(1),
2017 "a_scale", "height")))
2018 return failure();
2019 multiplesOfC = aScaleShape.getDimSize(2);
2020 }
2021
2022 const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType);
2023 if (bDataShape.hasRank()) {
2024 if (failed(tryUpdateDimOrFailure(*this, D, bDataShape.getDimSize(0),
2025 "b_data", "batch")) ||
2026 failed(tryUpdateDimOrFailure(*this, C, bDataShape.getDimSize(2),
2027 "b_data", "channels")))
2028 return failure();
2029 W = bDataShape.getDimSize(1);
2030 }
2031
2032 const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType());
2033 if (bScaleShape.hasRank()) {
2034 if (failed(tryUpdateDimOrFailure(*this, D, bScaleShape.getDimSize(0),
2035 "b_scale", "batch")) ||
2036 failed(tryUpdateDimOrFailure(*this, W, bScaleShape.getDimSize(1),
2037 "b_scale", "width")) ||
2038 failed(tryUpdateDimOrFailure(*this, multiplesOfC,
2039 bScaleShape.getDimSize(2), "b_scale",
2040 "C/block_size")))
2041 return failure();
2042 }
2043
2044 // Verify batch size is broadcast compatible
2045 if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1)
2046 return emitOpError("expect B matrix batch size to be broadcast compatible "
2047 "with A, got D=")
2048 << D << " vs N=" << N;
2049
2050 // Verify C is a multiple of block size
2051 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
2052 if (ShapedType::isStatic(C) && C % blockSize != 0)
2053 return emitOpError("expect C to be a multiple of block size, got C=")
2054 << C << ", block_size=" << blockSize;
2055
2056 // Verify multiplesOfC is C / block size
2057 if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) &&
2058 multiplesOfC != C / blockSize)
2059 return emitOpError(
2060 "expect scale operands dimension 2 to equal C/block_size (")
2061 << C << "/" << blockSize << ")" << ", got " << multiplesOfC;
2062
2063 // Verify output shape
2064 N = ShapedType::isDynamic(N) ? D : N;
2065 const SmallVector<int64_t, 3> expectedOutputShape = {N, H, W};
2066 const auto outputType = cast<ShapedType>(getResult().getType());
2067 if (outputType.hasRank() &&
2068 failed(
2069 verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) {
2070 InFlightDiagnostic opError = emitOpError("expected output shape ");
2071 auto stringifyDim = [&](int64_t d) {
2072 if (ShapedType::isDynamic(d))
2073 opError << "?";
2074 else
2075 opError << d;
2076 };
2077 llvm::interleaveComma(outputType.getShape(), opError, stringifyDim);
2078 opError << " to be compatible with expected output shape ";
2079 llvm::interleaveComma(expectedOutputShape, opError, stringifyDim);
2080 return opError;
2081 }
2082
2083 return success();
2084}
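// For example, assuming a block_size of 32: A_data of shape [2, 8, 64] needs
// A_scale of shape [2, 8, 64 / 32] = [2, 8, 2]; B_data of shape [1, 16, 64]
// (a batch of 1 broadcasts against A's batch of 2) with B_scale of shape
// [1, 16, 2] then yields an output of shape [N, H, W] = [2, 8, 16].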
2085
2086LogicalResult tosa::PadOp::inferReturnTypeComponents(
2087 MLIRContext *context, ::std::optional<Location> location,
2088 PadOp::Adaptor adaptor,
2089 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2090 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2091 auto paddingRank =
2092 cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
2093 SmallVector<int64_t> outputShape;
2094
2095 // If the input rank is unknown, we can infer the output rank using the
2096 // padding shape's rank divided by 2.
2097 if (!inputShape.hasRank()) {
2098 outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
2099 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2100 return success();
2101 }
2102
2103 SmallVector<int64_t> paddingValues;
2104 // If the paddings value is not a constant, all dimensions must be dynamic.
2105 if (!tosa::getConstShapeValues(adaptor.getPadding().getDefiningOp(),
2106 paddingValues)) {
2107 outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
2108 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2109 return success();
2110 }
2111
2112 outputShape.reserve(inputShape.getRank());
2113 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2114 if (inputShape.isDynamicDim(i)) {
2115 outputShape.push_back(ShapedType::kDynamic);
2116 continue;
2117 }
2118 auto padFront = paddingValues[i * 2];
2119 auto padBack = paddingValues[i * 2 + 1];
2120 if (padFront < 0 || padBack < 0) {
2121 // if either padding for dim i is -1, output dim is unknown
2122 outputShape.push_back(ShapedType::kDynamic);
2123 continue;
2124 }
2125
2126 outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
2127 }
2128
2129 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2130 return success();
2131}
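// For example, an input of shape [1, 3, 3] with constant padding values
// [0, 0, 1, 1, 2, 2] (front/back pairs per dimension) infers an output shape
// of [1 + 0 + 0, 3 + 1 + 1, 3 + 2 + 2] = [1, 5, 7].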
2132
2133LogicalResult tosa::PadOp::verify() {
2134 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2135 /* outType = */ getOutput().getType())
2136 .failed()) {
2137 return failure();
2138 }
2139
2140 if (auto padConst = getPadConst()) {
2141 if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
2142 /* outType = */ getOutput().getType())
2143 .failed()) {
2144 return failure();
2145 }
2146 }
2147
2148 RankedTensorType inputType =
2149 llvm::dyn_cast<RankedTensorType>(getInput1().getType());
2150 RankedTensorType outputType =
2151 llvm::dyn_cast<RankedTensorType>(getOutput().getType());
2152 if (!inputType || !outputType)
2153 return success();
2154
2155 if (failed(verifyRanksMatch(getOperation(), inputType, outputType, "input",
2156 "output")))
2157 return failure();
2158
2159 auto inputRank = inputType.getRank();
2160 DenseIntElementsAttr paddingAttr;
2161 if (!matchPattern(getPadding(), m_Constant(&paddingAttr)))
2162 return success();
2163
2164 auto paddingValues = paddingAttr.getValues<APInt>();
2165 if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
2166 return emitOpError() << "padding tensor must have " << inputRank
2167 << " * 2 = " << inputRank * 2 << " elements, but got "
2168 << paddingValues.size();
2169
2170 auto inputShape = inputType.getShape();
2171 auto outputShape = outputType.getShape();
2172
2173 for (int64_t i = 0; i < inputRank; ++i) {
2174 int64_t padStart = paddingValues[i * 2].getSExtValue();
2175 int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
2176
2177 if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
2178 return emitOpError()
2179 << "invalid padding values at dimension " << i
2180 << ": values must be non-negative or -1 for dynamic padding, got ["
2181 << padStart << ", " << padEnd << "]";
2182 }
2183
2184 // Skip shape verification for dynamic input/output
2185 if (inputShape[i] == ShapedType::kDynamic ||
2186 outputShape[i] == ShapedType::kDynamic)
2187 continue;
2188
2189 if (outputShape[i] != inputShape[i] + padStart + padEnd) {
2190 return emitOpError() << "mismatch in output shape at dimension " << i
2191 << ": expected " << inputShape[i] << " + "
2192 << padStart << " + " << padEnd << " = "
2193 << (inputShape[i] + padStart + padEnd)
2194 << ", but got " << outputShape[i];
2195 }
2196 }
2197
2198 return success();
2199}
2200
2201LogicalResult tosa::SliceOp::inferReturnTypeComponents(
2202 MLIRContext *context, ::std::optional<Location> location,
2203 SliceOp::Adaptor adaptor,
2204 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2205
2206 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2207 SmallVector<int64_t> start;
2208 SmallVector<int64_t> size;
2209
2210 if (!tosa::getConstShapeValues(adaptor.getStart().getDefiningOp(), start) ||
2211 !tosa::getConstShapeValues(adaptor.getSize().getDefiningOp(), size)) {
2212 auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
2213 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2214 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2215 return success();
2216 }
2217
2218 // if size[i] is -1, all remaining elements in dimension i are included
2219 // in the slice, similar to TF.
2220 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2221 // initialize outputShape to all unknown
2222 SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
2223 if (inputShape.hasRank()) {
2224 for (size_t i = 0; i < size.size(); i++) {
2225 if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
2226 (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
2227 start[i] < inputShape.getDimSize(i))) {
2228 // size[i] is not 0 and not < -1, and start[i] is in valid range
2229 if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
2230 // input shape has unknown dim[i] - only valid if size[i] > 0
2231 if (size[i] > 0) {
2232 outputShape[i] = size[i];
2233 }
2234 } else {
2235 // input shape has known dim[i]
2236 if (size[i] == -1) {
2237 outputShape[i] = inputShape.getDimSize(i) - start[i];
2238 } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
2239 // start[i] + size[i] is within bound of input shape's dim[i]
2240 outputShape[i] = size[i];
2241 }
2242 }
2243 }
2244 }
2245 } else {
2246 outputShape = convertToMlirShape(size);
2247 }
2248 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2249 return success();
2250}
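// For example, slicing an input of shape [4, 8] with start = [1, 2] and
// size = [2, -1] infers an output of shape [2, 6]: dimension 0 takes the
// requested size 2, and size -1 on dimension 1 takes the remaining 8 - 2 = 6
// elements.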
2251
2252LogicalResult tosa::SliceOp::verify() {
2253 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2254 /* outType = */ getOutput().getType())
2255 .failed())
2256 return failure();
2257
2258 const ShapeAdaptor inputShape(getInput1().getType());
2259 if (inputShape.hasRank()) {
2260 const auto inputRank = inputShape.getRank();
2261 const ShapeAdaptor outputShape(getOutput().getType());
2262 if (outputShape.hasRank() && inputRank != outputShape.getRank())
2263 return emitOpError(
2264 "expect input1 and output to have the same ranks, got ")
2265 << inputRank << " and " << outputShape.getRank();
2266
2267 const auto startShapeRank =
2268 llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
2269 if (inputRank != startShapeRank)
2270 return emitOpError("length of start is not equal to rank of input shape");
2271
2272 const auto sizeShapeRank =
2273 llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
2274 if (inputRank != sizeShapeRank)
2275 return emitOpError("length of size is not equal to rank of input shape");
2276 }
2277
2278 return success();
2279}
2280
2281LogicalResult tosa::MulOp::inferReturnTypeComponents(
2282 MLIRContext *context, ::std::optional<Location> location,
2283 ValueShapeRange operands, DictionaryAttr attributes,
2284 OpaqueProperties properties, RegionRange regions,
2285 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2286 // mul op's output shape only depends on input1 and input2, not on shift
2287 ValueShapeRange twoInputs = operands.drop_back();
2288 llvm::SmallVector<int64_t> outShape;
2289 if (resolveBroadcastShape(twoInputs, outShape).failed()) {
2290 inferredReturnShapes.push_back(ShapedTypeComponents());
2291 } else {
2292 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
2293 }
2294 return success();
2295}
2296
2297LogicalResult tosa::MulOp::verify() {
2298 const Value output = getOutput();
2299 auto resElemType = getElementTypeOrSelf(output);
2300
2301 // Verify that the element types of the operands and the result match the
2302 // TOSA specification.
2303 if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
2304 IntegerType lhsIntType =
2305 dyn_cast<IntegerType>(getElementTypeOrSelf(getInput1()));
2306 IntegerType rhsIntType =
2307 dyn_cast<IntegerType>(getElementTypeOrSelf(getInput2()));
2308 if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
2309 return emitOpError("requires the same element type for all operands");
2310
2311 // Though the spec requires the element type of result to be i32, a more
2312 // relaxed way is provided at dialect level for easier cooperating with
2313 // other dialects.
2314 if (lhsIntType.getWidth() > resIntType.getWidth())
2315 return emitOpError("invalid data type size for operands or result");
2316
2317 } else {
2318 // For the other supported types, the spec requires the same element
2319 // type for all operands (excluding the `shift` operand) and results.
2320 for (int i = 0; i < 2; ++i) {
2321 if (getElementTypeOrSelf(getOperand(i)) != resElemType)
2322 return emitOpError(
2323 "requires the same element type for all operands and results");
2324 }
2325
2326 // verify shift has value 0 for non-integer types
2327 ElementsAttr shiftElem;
2328 if (matchPattern(getShift(), m_Constant(&shiftElem))) {
2329 int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
2330 if (shift != 0) {
2331 return emitOpError() << "require shift to be 0 for float type";
2332 }
2333 }
2334 }
2335
2336 // Verify that all main operands and result types have the same rank, if
2337 // known. Extra operands such as mul's shift are excluded, which is the only
2338 // difference from the built-in `SameOperandsAndResultRank` trait.
2339 TypeRange operandTypes = getOperandTypes();
2340 ShapedType aType = cast<ShapedType>(operandTypes[0]);
2341 ShapedType bType = cast<ShapedType>(operandTypes[1]);
2342
2343 const bool aHasRank = aType.hasRank();
2344 const bool bHasRank = bType.hasRank();
2345 if (aHasRank && bHasRank) {
2346 const int64_t aRank = aType.getRank();
2347 const int64_t bRank = bType.getRank();
2348 if (aRank != bRank)
2349 return emitOpError("a and b operands don't have matching ranks, got ")
2350 << aRank << " and " << bRank;
2351
2352 // check for broadcast compatible shapes
2353 SmallVector<int64_t> resultShape;
2354 if (!mlir::OpTrait::util::getBroadcastedShape(
2355 aType.getShape(), bType.getShape(), resultShape))
2356 return emitOpError("a and b operands don't have broadcast-compatible "
2357 "shapes, got ")
2358 << aType << " and " << bType;
2359 }
2360
2361 ShapedType resultType = cast<ShapedType>(output.getType());
2362 if (!resultType.hasRank())
2363 return success();
2364
2365 const int64_t resultRank = resultType.getRank();
2366 if (aHasRank && resultRank != aType.getRank())
2367 return emitOpError("result type has different rank than a, got ")
2368 << resultRank << " vs " << aType.getRank();
2369 if (bHasRank && resultRank != bType.getRank())
2370 return emitOpError("result type has different rank than b, got ")
2371 << resultRank << " vs " << bType.getRank();
2372
2373 return success();
2374}
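// For example, i8 * i8 producing an i32 result passes the integer checks above
// since the result may be wider than the operands, whereas i32 * i32 producing
// an i8 result is rejected; for float operands the element types must all
// match the result type and a constant shift operand must be 0.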
2375
2376LogicalResult tosa::TableOp::inferReturnTypeComponents(
2377 MLIRContext *context, ::std::optional<Location> location,
2378 TableOp::Adaptor adaptor,
2379 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2380 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2381
2382 if (!inputShape.hasRank()) {
2383 inferredReturnShapes.push_back(ShapedTypeComponents());
2384 return success();
2385 }
2386
2387 inferredReturnShapes.resize(1);
2388 inputShape.getDims(inferredReturnShapes[0]);
2389 return success();
2390}
2391
2392LogicalResult tosa::TableOp::verify() {
2393 const TensorType inputType = getInput1().getType();
2394 const TensorType outputType = getOutput().getType();
2395
2396 if (!inputType.hasRank() || !outputType.hasRank())
2397 return success();
2398
2399 if (failed(verifyRanksMatch(getOperation(), inputType, outputType, "input",
2400 "result")))
2401 return failure();
2402
2403 auto inputDims = inputType.getShape();
2404 auto outputDims = outputType.getShape();
2405 for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
2406 int64_t dim = it.index();
2407 auto [inputDim, outputDim] = it.value();
2408 if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
2409 return emitOpError() << "dim(result, " << dim << ") = " << outputDim
2410 << " doesn't match dim(input, " << dim
2411 << ") = " << inputDim;
2412 }
2413 }
2414 return success();
2415}
2416
2417LogicalResult
2418tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
2419 // Multiples must be constants.
2420 DenseIntElementsAttr multiplesAttr;
2421 if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr)))
2422 return failure();
2423 multiples =
2424 llvm::map_to_vector(multiplesAttr.getValues<APInt>(),
2425 [](const APInt &val) { return val.getSExtValue(); });
2426 return success();
2427}
2428
2429LogicalResult tosa::TileOp::inferReturnTypeComponents(
2430 MLIRContext *context, ::std::optional<Location> location,
2431 TileOp::Adaptor adaptor,
2432 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2433 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2434 SmallVector<int64_t> multiples;
2435 if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(),
2436 multiples)) {
2437 auto rank =
2438 cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
2439 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2440 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2441 return success();
2442 } else {
2443 multiples = convertToMlirShape(multiples);
2444 }
2445
2446 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2447 SmallVector<int64_t> outputShape;
2448 if (!inputShape.hasRank()) {
2449 outputShape.resize(multiples.size(), ShapedType::kDynamic);
2450 inferredReturnShapes.push_back(
2451 ShapedTypeComponents(outputShape, inputType));
2452 return success();
2453 } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
2454 return failure();
2455
2456 // Any non-dynamic dimension can be multiplied out to a known size.
2457 outputShape.reserve(multiples.size());
2458 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2459 if (multiples[i] == ShapedType::kDynamic) {
2460 outputShape.push_back(ShapedType::kDynamic);
2461 } else {
2462 int64_t dim = inputShape.getDimSize(i);
2463 if (dim != ShapedType::kDynamic)
2464 dim *= multiples[i];
2465 outputShape.push_back(dim);
2466 }
2467 }
2468
2469 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
2470 return success();
2471}
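// For example, tiling an input of shape [2, ?, 3] with constant multiples
// [3, 1, 2] infers an output of shape [6, ?, 6]; a multiple of -1 (converted
// to a dynamic value above) makes the corresponding output dimension dynamic.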
2472
2473LogicalResult tosa::TileOp::verify() {
2474 if (verifySameElementTypes(*this, /* intype = */ getInput1().getType(),
2475 /* outType = */ getOutput().getType())
2476 .failed()) {
2477 return failure();
2478 }
2479 ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
2480 ShapedType outputType = llvm::cast<ShapedType>(getType());
2481
2482 shapeType multiplesType =
2483 llvm::cast<tosa::shapeType>(getMultiples().getType());
2484
2485 auto multiplesRank = multiplesType.getRank();
2486
2487 if (inputType.hasRank()) {
2488 if (inputType.getRank() != multiplesRank)
2489 return emitOpError("expect 'multiples' to have rank ")
2490 << inputType.getRank() << " but got " << multiplesRank << ".";
2491 if (outputType.hasRank() &&
2492 failed(verifyRanksMatch(getOperation(), inputType, outputType, "input",
2493 "output")))
2494 return failure();
2495 } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
2496 return emitOpError("expect 'multiples' array to have length ")
2497 << outputType.getRank() << " but got " << multiplesRank << ".";
2498
2499 SmallVector<int64_t> multiples;
2500 if (getConstantMultiples(multiples).succeeded() &&
2501 llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
2502 return emitOpError(
2503 "expect element of 'multiples' to be positive integer or -1.");
2504
2505 return success();
2506}
2507
2508bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
2509 if (l.size() != r.size() || l.size() != 1)
2510 return false;
2511 return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
2512}
2513
2514LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
2515 MLIRContext *context, ::std::optional<Location> location,
2516 ReshapeOp::Adaptor adaptor,
2517 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2518 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2519 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2520 llvm::SmallVector<int64_t> newShapeValue;
2521 if (!tosa::getConstShapeValues(adaptor.getShape().getDefiningOp(),
2522 newShapeValue)) {
2523 auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
2524 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2525 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2526 return success();
2527 } else {
2528 newShapeValue = convertToMlirShape(newShapeValue);
2529 }
2530
2531 // We cannot infer from the total number of elements so we must take the
2532 // shape attribute as exact.
2533 if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
2534 inferredReturnShapes.push_back(
2535 ShapedTypeComponents(newShapeValue, inputType));
2536 return success();
2537 }
2538
2539 // Determine the number of elements covered by the slice of all static
2540 // dimensions. This allows us to infer the length of the remaining dynamic
2541 // dimension.
2542 int64_t numElements = inputShape.getNumElements();
2543 int64_t staticMul = 1;
2544 for (auto val : newShapeValue) {
2545 if (ShapedType::isStatic(val)) {
2546 staticMul *= val;
2547 }
2548 }
2549
2550 // Determine the length of the dynamic dimension.
2551 for (auto &val : newShapeValue) {
2552 if (ShapedType::isDynamic(val))
2553 val = numElements / staticMul;
2554 }
2555
2556 inferredReturnShapes.push_back(
2557 ShapedTypeComponents(newShapeValue, inputType));
2558 return success();
2559}
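// For example, reshaping a fully static input holding 24 elements with the
// constant shape [4, -1]: the -1 becomes a dynamic dimension whose extent is
// inferred as 24 / 4 = 6, so the inferred output shape is [4, 6].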
2560
2561llvm::LogicalResult tosa::ReshapeOp::verify() {
2562 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2563 /* outType = */ getOutput().getType())
2564 .failed()) {
2565 return failure();
2566 }
2567 TensorType inputType = getInput1().getType();
2568
2569 SmallVector<int64_t> shapeValues;
2570 if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) {
2571 // skip following checks if shape is not constant
2572 return mlir::success();
2573 }
2574
2575 int missingDims = llvm::count(shapeValues, -1);
2576 if (missingDims > 1)
2577 return emitOpError() << "expected at most one target dimension to be -1";
2578
2579 const auto outputType = dyn_cast<RankedTensorType>(getType());
2580 if (!outputType)
2581 return success();
2582
2583 if ((int64_t)shapeValues.size() != outputType.getRank())
2584 return emitOpError() << "new shape does not match result rank";
2585
2586 for (auto [newShapeDim, outputShapeDim] :
2587 zip(shapeValues, outputType.getShape())) {
2588 if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
2589 outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
2590 return emitOpError() << "new shape is inconsistent with result shape";
2591
2592 if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
2593 return emitOpError() << "new shape has invalid tensor dimension size "
2594 << newShapeDim;
2595 }
2596
2597 if (inputType.hasStaticShape()) {
2598 int64_t inputElementsNum = inputType.getNumElements();
2599 if (outputType.hasStaticShape()) {
2600 int64_t outputElementsNum = outputType.getNumElements();
2601 if (inputElementsNum != outputElementsNum) {
2602 return emitOpError() << "cannot reshape " << inputElementsNum
2603 << " elements into " << outputElementsNum;
2604 }
2605 }
2606
2607 int64_t newShapeElementsNum =
2608 llvm::accumulate(shapeValues, int64_t(1), [](int64_t acc, int64_t dim) {
2609 return (dim > 0) ? acc * dim : acc;
2610 });
2611 bool isStaticNewShape =
2612 llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
2613 if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
2614 (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
2615 return emitOpError() << "cannot reshape " << inputElementsNum
2616 << " elements into " << newShapeElementsNum;
2617 }
2618 }
2619
2620 return mlir::success();
2621}
2622
2623// Return failure if val is not a constant.
2624// Set zp to -1 if val is a non-zero float or is neither integer nor float;
2625// otherwise set zp to val's constant value.
2626static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
2627 ElementsAttr zpAttr;
2628 if (!matchPattern(val, m_Constant(&zpAttr))) {
2629 return failure();
2630 }
2631
2632 Type zpElemType = zpAttr.getElementType();
2633
2634 if (llvm::isa<FloatType>(zpElemType)) {
2635 if (zpAttr.getValues<APFloat>()[0].isZero()) {
2636 return 0;
2637 }
2638 // return non-zero value to trigger error check
2639 return -1;
2640 }
2641
2642 if (llvm::isa<IntegerType>(zpElemType)) {
2643 if (signExtend)
2644 return zpAttr.getValues<APInt>()[0].getSExtValue();
2645 else
2646 return zpAttr.getValues<APInt>()[0].getZExtValue();
2647 }
2648
2649 // return non-zero value to trigger error check
2650 return -1;
2651}
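// For example, getZeroPoint returns 0 for a constant float 0.0 zero point, the
// sign- or zero-extended constant for an integer zero point, and -1 for a
// non-zero float so that the later verification reports an error; a
// non-constant zero point yields failure().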
2652
2653template <typename T>
2654static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
2655 const std::string &operand) {
2656 Type zpElemType = getElementTypeOrSelf(val);
2657
2658 if (!zpElemType.isInteger(8) && zp != 0) {
2659 // convert operand to lower case for error message
2660 std::string lower = operand;
2661 llvm::transform(lower, lower.begin(), ::tolower);
2662 return op.emitOpError()
2663 << lower << " zero point must be zero for non-int8 integer types";
2664 }
2665
2666 return success();
2667}
2668
2669static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
2670 const int64_t &zp,
2671 const std::string &operand) {
2672 bool isInputZp = (operand == "Input");
2673
2674 bool tensorUnsigned =
2675 isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
2676 StringRef tensorName = isInputZp ? "input" : "output";
2677
2678 Type zpElemType = getElementTypeOrSelf(zpVal);
2679
2680 if (zp != 0) {
2681 if (!zpElemType.isInteger(8) &&
2682 !(zpElemType.isInteger(16) && tensorUnsigned)) {
2683 return op.emitOpError()
2684 << "expect " << tensorName << "_zp of 0, got " << zp;
2685 }
2686 if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
2687 return op.emitOpError() << "expect " << tensorName
2688 << "_zp of 0 or 32768 for unsigned int16 "
2689 << tensorName << ", got " << zp;
2690 }
2691 }
2692
2693 return success();
2694}
2695
2696#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND) \
2697 FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
2698 return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND); \
2699 } \
2700 LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
2701 return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
2702 }
2703
2704ZERO_POINT_HELPER(Conv2DOp, Input, true)
2705ZERO_POINT_HELPER(Conv2DOp, Weight, true)
2706ZERO_POINT_HELPER(Conv3DOp, Input, true)
2707ZERO_POINT_HELPER(Conv3DOp, Weight, true)
2708ZERO_POINT_HELPER(DepthwiseConv2DOp, Input, true)
2709ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight, true)
2710ZERO_POINT_HELPER(TransposeConv2DOp, Input, true)
2711ZERO_POINT_HELPER(TransposeConv2DOp, Weight, true)
2712ZERO_POINT_HELPER(AvgPool2dOp, Input, true)
2713ZERO_POINT_HELPER(AvgPool2dOp, Output, true)
2714ZERO_POINT_HELPER(MatMulOp, A, true)
2715ZERO_POINT_HELPER(MatMulOp, B, true)
2716ZERO_POINT_HELPER(NegateOp, Input1, true)
2717ZERO_POINT_HELPER(NegateOp, Output, true)
2718ZERO_POINT_HELPER(RescaleOp, Input, !getInputUnsigned())
2719ZERO_POINT_HELPER(RescaleOp, Output, !getOutputUnsigned())
2720#undef ZERO_POINT_HELPER
2721
2722LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
2723 MLIRContext *context, ::std::optional<Location> location,
2724 TransposeOp::Adaptor adaptor,
2725 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2726 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2727
2728 // If input rank and permutation length is unknown, the output rank is
2729 // unknown.
2730 if (!inputShape.hasRank()) {
2731 inferredReturnShapes.push_back(ShapedTypeComponents());
2732 return success();
2733 }
2734
2735 const auto inputRank = inputShape.getRank();
2736
2737 // This would imply the number of permutations does not match the rank of
2738 // the input which is illegal.
2739 if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {
2740 return failure();
2741 }
2742
2743 SmallVector<int64_t> outputShape;
2744 // Rank-0 means no permutations matter.
2745 if (inputRank == 0) {
2746 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2747 return success();
2748 }
2749
2750 // Check whether the input dimensions are all the same.
2751 bool allTheSame = true;
2752 for (int i = 1, s = inputRank; i < s; i++) {
2753 if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
2754 allTheSame = false;
2755 break;
2756 }
2757 }
2758
2759 // If all of the input dimensions are the same we don't care about the
2760 // permutation.
2761 if (allTheSame) {
2762 outputShape.resize(inputRank, inputShape.getDimSize(0));
2763 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2764 return success();
2765 }
2766
2767 outputShape.resize(inputRank, ShapedType::kDynamic);
2768
2769 // Constant permutation values must be within the input rank.
2770 if (llvm::any_of(adaptor.getPerms(),
2771 [inputRank](const auto i) { return i >= inputRank; }))
2772 return failure();
2773
2774 outputShape.reserve(inputRank);
2775 for (int i = 0, s = inputRank; i < s; i++) {
2776 outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
2777 }
2778
2779 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2780 return success();
2781}
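// For example, an input of shape [1, 2, 3] with perms = [2, 0, 1] infers an
// output of shape [3, 1, 2], since output dimension i takes input dimension
// perms[i]; when every input dimension has the same size the permutation is
// irrelevant and the input shape is reused directly.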
2782
2783LogicalResult tosa::TransposeOp::verify() {
2784 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2785 /* outType = */ getOutput().getType())
2786 .failed()) {
2787 return failure();
2788 }
2789
2790 const ShapeAdaptor inputShape(getInput1().getType());
2791 const ShapeAdaptor outputShape(getOutput().getType());
2792
2793 const llvm::ArrayRef<int32_t> constantPerms = getPerms();
2794
2795 if (inputShape.hasRank() &&
2796 constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
2797 return emitOpError() << "expected perms attribute to have size "
2798 << inputShape.getRank()
2799 << " (input rank) but got size "
2800 << constantPerms.size();
2801
2802 if (inputShape.hasRank() && outputShape.hasRank() &&
2803 inputShape.getRank() != outputShape.getRank())
2804 return emitOpError()
2805 << "expected input tensor rank to equal result tensor rank";
2806
2807 if (outputShape.hasRank() &&
2808 constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
2809 return emitOpError() << "expected perms attribute to have size "
2810 << outputShape.getRank()
2811 << " (output rank) but got size "
2812 << constantPerms.size();
2813
2814 if (!llvm::all_of(constantPerms,
2815 [&constantPerms](int32_t s) {
2816 return s >= 0 &&
2817 static_cast<size_t>(s) < constantPerms.size();
2818 }) ||
2819 !isPermutationVector(llvm::map_to_vector(
2820 constantPerms, [](int32_t v) -> int64_t { return v; })))
2821 return emitOpError() << "expected valid permutation indices";
2822
2823 // ERROR_IF(tensor_size(shape1) != tensor_size(shape))
2824 if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
2825 inputShape.getNumElements() != outputShape.getNumElements())
2826 return emitOpError() << "expected input1 and output to have same numbers "
2827 "of elements, got "
2828 << inputShape.getNumElements() << " and "
2829 << outputShape.getNumElements();
2830
2831 // Verify that the types of the input and output tensors are properly
2832 // permuted.
2833 if (inputShape.hasRank() && outputShape.hasRank()) {
2834 for (auto i = 0; i < outputShape.getRank(); i++) {
2835 if (inputShape.isDynamicDim(constantPerms[i]) ||
2836 outputShape.isDynamicDim(i))
2837 continue;
2838
2839 if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
2840 return emitOpError()
2841 << "expected output tensor dim " << i << " to match "
2842 << "input dim " << constantPerms[i] << " with value of "
2843 << inputShape.getDimSize(constantPerms[i]);
2844 }
2845 }
2846
2847 return success();
2848}
2849
2850LogicalResult TransposeOp::reifyResultShapes(
2851 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2852
2853 const llvm::ArrayRef<int32_t> transposePerms = getPerms();
2854
2855 Value input = getInput1();
2856 auto inputType = cast<TensorType>(input.getType());
2857
2858 SmallVector<OpFoldResult> returnedDims(inputType.getRank());
2859 for (auto dim : transposePerms) {
2860 int32_t dimInInput = transposePerms[dim];
2861 if (inputType.isDynamicDim(dimInInput))
2862 returnedDims[dim] =
2863 tensor::DimOp::create(builder, getLoc(), input, dimInInput)
2864 .getResult();
2865 else
2866 returnedDims[dim] =
2867 builder.getIndexAttr(inputType.getDimSize(dimInInput));
2868 }
2869
2870 reifiedReturnShapes.emplace_back(std::move(returnedDims));
2871 return success();
2872}
2873
2874LogicalResult tosa::GatherOp::inferReturnTypeComponents(
2875 MLIRContext *context, ::std::optional<Location> location,
2876 GatherOp::Adaptor adaptor,
2877 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2878 llvm::SmallVector<int64_t> outputShape;
2879 outputShape.resize(3, ShapedType::kDynamic);
2880
2881 ShapeAdaptor valuesShape(adaptor.getValues().getType());
2882 if (valuesShape.hasRank()) {
2883 outputShape[0] = valuesShape.getDimSize(0);
2884 outputShape[2] = valuesShape.getDimSize(2);
2885 }
2886
2887 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2888 if (indicesShape.hasRank()) {
2889 if (outputShape[0] == ShapedType::kDynamic)
2890 outputShape[0] = indicesShape.getDimSize(0);
2891 if (outputShape[1] == ShapedType::kDynamic)
2892 outputShape[1] = indicesShape.getDimSize(1);
2893 }
2894
2895 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2896 return success();
2897}
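// For example, values of shape [N, K, C] = [2, 5, 3] and indices of shape
// [N, W] = [2, 4] infer an output of shape [N, W, C] = [2, 4, 3]; N and C come
// from values (with N falling back to indices when dynamic) and W comes from
// indices.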
2898
2899LogicalResult tosa::GatherOp::verify() {
2900 if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
2901 /* outType = */ getOutput().getType())
2902 .failed()) {
2903 return failure();
2904 }
2905
2906 const ShapeAdaptor valuesShape(getValues().getType());
2907 const ShapeAdaptor indicesShape(getIndices().getType());
2908 const ShapeAdaptor outputShape(getOutput().getType());
2909
2910 int64_t n = ShapedType::kDynamic;
2911 int64_t w = ShapedType::kDynamic;
2912 int64_t c = ShapedType::kDynamic;
2913
2914 if (valuesShape.hasRank()) {
2915 n = valuesShape.getDimSize(0);
2916 c = valuesShape.getDimSize(2);
2917 }
2918 if (indicesShape.hasRank()) {
2919 const int64_t indicesN = indicesShape.getDimSize(0);
2920 w = indicesShape.getDimSize(1);
2921 if (n == ShapedType::kDynamic)
2922 n = indicesN;
2923 else if (indicesN != ShapedType::kDynamic && n != indicesN)
2924 return emitOpError() << "requires indices dimension 0 to have size " << n
2925 << ", got " << indicesN;
2926 }
2927 if (outputShape.hasRank()) {
2928 const int64_t outputN = outputShape.getDimSize(0);
2929 const int64_t outputW = outputShape.getDimSize(1);
2930 const int64_t outputC = outputShape.getDimSize(2);
2931 if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2932 n != outputN)
2933 return emitOpError() << "requires output dimension 0 to have size " << n
2934 << ", got " << outputN;
2935
2936 if (w != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
2937 w != outputW)
2938 return emitOpError() << "requires output dimension 1 to have size " << w
2939 << ", got " << outputW;
2940 if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2941 c != outputC)
2942 return emitOpError() << "requires output dimension 2 to have size " << c
2943 << ", got " << outputC;
2944 }
2945 return success();
2946}
2947
2948LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
2949 MLIRContext *context, ::std::optional<Location> location,
2950 ResizeOp::Adaptor adaptor,
2951 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2952 llvm::SmallVector<int64_t, 4> outputShape;
2953 outputShape.resize(4, ShapedType::kDynamic);
2954
2955 ShapeAdaptor inputShape(adaptor.getInput().getType());
2956 if (!inputShape.hasRank())
2957 return failure();
2958
2959 outputShape[0] = inputShape.getDimSize(0);
2960 outputShape[3] = inputShape.getDimSize(3);
2961 int64_t inputHeight = inputShape.getDimSize(1);
2962 int64_t inputWidth = inputShape.getDimSize(2);
2963
2964 if ((inputHeight == ShapedType::kDynamic) ||
2965 (inputWidth == ShapedType::kDynamic))
2966 return failure();
2967
2968 SmallVector<int64_t> scaleInt, offsetInt, borderInt;
2969 if (!tosa::getConstShapeValues(adaptor.getScale().getDefiningOp(),
2970 scaleInt) ||
2971 !tosa::getConstShapeValues(adaptor.getOffset().getDefiningOp(),
2972 offsetInt) ||
2973 !tosa::getConstShapeValues(adaptor.getBorder().getDefiningOp(),
2974 borderInt)) {
2975 return failure();
2976 }
2977
2978 // Compute the output shape based on attributes: scale, offset, and border.
2979 const int64_t outputHeight =
2980 (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
2981 scaleInt[1]) +
2982 1;
2983
2984 const int64_t outputWidth =
2985 (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
2986 scaleInt[3]) +
2987 1;
2988
2989 if (outputHeight < 0 || outputWidth < 0) {
2990 return emitOptionalError(
2991 location,
2992 "calculated output height and width must be non-negative, "
2993 "got height = ",
2994 outputHeight, ", width = ", outputWidth);
2995 }
2996
2997 outputShape[1] = outputHeight;
2998 outputShape[2] = outputWidth;
2999 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3000 return success();
3001}
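// For example, an input of shape [1, 4, 8, 3] with constant scale =
// [2, 1, 2, 1], offset = [0, 0], and border = [0, 0] gives
//   output_height = ((4 - 1) * 2 - 0 + 0) / 1 + 1 = 7
//   output_width  = ((8 - 1) * 2 - 0 + 0) / 1 + 1 = 15
// i.e. an inferred output shape of [1, 7, 15, 3].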
3002
3003LogicalResult tosa::ResizeOp::verify() {
3004 const Value input = getInput();
3005 const Value output = getOutput();
3006 const RankedTensorType inputType =
3007 llvm::dyn_cast<RankedTensorType>(input.getType());
3008 const RankedTensorType outputType =
3009 llvm::dyn_cast<RankedTensorType>(output.getType());
3010
3011 SmallVector<int64_t> scaleValues;
3012 SmallVector<int64_t> offsetValues;
3013 SmallVector<int64_t> borderValues;
3014 if (!tosa::getConstShapeValues(getScale().getDefiningOp(), scaleValues) ||
3015 !tosa::getConstShapeValues(getOffset().getDefiningOp(), offsetValues) ||
3016 !tosa::getConstShapeValues(getBorder().getDefiningOp(), borderValues)) {
3017 // Skip following checks if shape is not constant
3018 return success();
3019 }
3020
3021 if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
3022 return emitOpError("expect all scale values to be > 0, got ")
3023 << scaleValues;
3024
3025 const int64_t scaleYN = scaleValues[0];
3026 const int64_t scaleYD = scaleValues[1];
3027 const int64_t scaleXN = scaleValues[2];
3028 const int64_t scaleXD = scaleValues[3];
3029
3030 const int64_t offsetY = offsetValues[0];
3031 const int64_t offsetX = offsetValues[1];
3032
3033 const int64_t borderY = borderValues[0];
3034 const int64_t borderX = borderValues[1];
3035
3036 if (!inputType)
3037 return success();
3038 if (!outputType)
3039 return success();
3040
3041 const int64_t oh = outputType.getDimSize(1);
3042 const int64_t ow = outputType.getDimSize(2);
3043 const int64_t ih = inputType.getDimSize(1);
3044 const int64_t iw = inputType.getDimSize(2);
3045
3046 // Don't check an input height that could be broadcast (hence the ih != 1
3047 // guard below), since Linalg, a consumer of TOSA, expects broadcasting
3048 // support in resize to be available. Taking the cautious approach for now,
3049 // we can consider removing support for broadcasting later.
3050 if (ih != ShapedType::kDynamic && ih != 1) {
3051 const std::optional<int64_t> calculatedOutHeightMinusOne =
3052 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
3053 if (!calculatedOutHeightMinusOne.has_value())
3054 return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
3055 "border_y ")
3056 << "to be wholly divisible by scale_y_d, got ((" << ih
3057 << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
3058 << ") / " << scaleYD;
3059 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
3060 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
3061 return emitOpError("calculated output height did not match expected: ")
3062 << "calculated=" << calculatedOutHeight << ", expected=" << oh;
3063 }
3064
3065 // Skip the check when the input width could be broadcast (iw == 1),
3066 // since Linalg, a consumer of TOSA, expects broadcasting support in
3067 // resize to be available. Taking the cautious approach for now, we can
3068 // consider removing support for broadcasting later.
3069 if (iw != ShapedType::kDynamic && iw != 1) {
3070 const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
3071 const std::optional<int64_t> calculatedOutWidthMinusOne =
3072 idivCheck(scaledInWidth, scaleXD);
3073 if (!calculatedOutWidthMinusOne.has_value())
3074 return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
3075 "border_x ")
3076 << "to be wholly divisible by scale_x_d, got ((" << iw
3077 << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
3078 << ") / " << scaleXD;
3079 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
3080 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
3081 return emitOpError("calculated output width did not match expected: ")
3082 << "calculated=" << calculatedOutWidth << ", expected=" << ow;
3083 }
3084
3085 return success();
3086}
3087
3088LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
3089 MLIRContext *context, ::std::optional<Location> location,
3090 ScatterOp::Adaptor adaptor,
3091 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3092 llvm::SmallVector<int64_t> outputShape;
3093 outputShape.resize(3, ShapedType::kDynamic);
3094
3095 ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
3096 if (valuesInShape.hasRank()) {
3097 outputShape[0] = valuesInShape.getDimSize(0);
3098 outputShape[1] = valuesInShape.getDimSize(1);
3099 outputShape[2] = valuesInShape.getDimSize(2);
3100 }
3101
3102 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3103 if (indicesShape.hasRank()) {
3104 if (outputShape[0] == ShapedType::kDynamic)
3105 outputShape[0] = indicesShape.getDimSize(0);
3106 }
3107
3108 ShapeAdaptor inputShape(adaptor.getInput().getType());
3109 if (inputShape.hasRank()) {
3110 if (outputShape[0] == ShapedType::kDynamic)
3111 outputShape[0] = inputShape.getDimSize(0);
3112 if (outputShape[2] == ShapedType::kDynamic)
3113 outputShape[2] = inputShape.getDimSize(2);
3114 }
3115
3116 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3117 return success();
3118}
3119
3120LogicalResult tosa::ScatterOp::verify() {
3121 if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
3122 /* outType = */ getValuesOut().getType())
3123 .failed() ||
3124 verifySameElementTypes(*this, /* inType = */ getInput().getType(),
3125 /* outType = */ getValuesOut().getType())
3126 .failed()) {
3127 return failure();
3128 }
3129
3130 const ShapeAdaptor valuesInShape(getValuesIn().getType());
3131 const ShapeAdaptor indicesShape(getIndices().getType());
3132 const ShapeAdaptor inputShape(getInput().getType());
3133 const ShapeAdaptor outputShape(getValuesOut().getType());
3134
3135 int64_t n = ShapedType::kDynamic;
3136 int64_t k = ShapedType::kDynamic;
3137 int64_t w = ShapedType::kDynamic;
3138 int64_t c = ShapedType::kDynamic;
3139 if (valuesInShape.hasRank()) {
3140 n = valuesInShape.getDimSize(0);
3141 k = valuesInShape.getDimSize(1);
3142 c = valuesInShape.getDimSize(2);
3143 }
3144 if (indicesShape.hasRank()) {
3145 const int64_t indicesN = indicesShape.getDimSize(0);
3146 w = indicesShape.getDimSize(1);
3147 if (n == ShapedType::kDynamic)
3148 n = indicesN;
3149 else if (indicesN != ShapedType::kDynamic && n != indicesN)
3150 return emitOpError() << "requires indices dimension 0 to have size " << n
3151 << ", got " << indicesN;
3152 }
3153 if (inputShape.hasRank()) {
3154 const int64_t inputN = inputShape.getDimSize(0);
3155 const int64_t inputW = inputShape.getDimSize(1);
3156 const int64_t inputC = inputShape.getDimSize(2);
3157 if (n == ShapedType::kDynamic)
3158 n = inputN;
3159 else if (inputN != ShapedType::kDynamic && n != inputN)
3160 return emitOpError() << "requires input dimension 0 to have size " << n
3161 << ", got " << inputN;
3162 if (w == ShapedType::kDynamic)
3163 w = inputW;
3164 else if (inputW != ShapedType::kDynamic && w != inputW)
3165 return emitOpError() << "requires input dimension 1 to have size " << w
3166 << ", got " << inputW;
3167
3168 if (c == ShapedType::kDynamic)
3169 c = inputC;
3170 else if (inputC != ShapedType::kDynamic && c != inputC)
3171 return emitOpError() << "requires input dimension 2 to have size " << c
3172 << ", got " << inputC;
3173 }
3174 if (outputShape.hasRank()) {
3175 const int64_t outputN = outputShape.getDimSize(0);
3176 const int64_t outputK = outputShape.getDimSize(1);
3177 const int64_t outputC = outputShape.getDimSize(2);
3178 if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
3179 n != outputN)
3180 return emitOpError() << "requires values_out dimension 0 to have size "
3181 << n << ", got " << outputN;
3182 if (k == ShapedType::kDynamic)
3183 k = outputK;
3184 else if (outputK != ShapedType::kDynamic && k != outputK)
3185 return emitOpError() << "requires values_out dimension 1 to have size "
3186 << k << ", got " << outputK;
3187 if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
3188 c != outputC)
3189 return emitOpError() << "requires values_out dimension 2 to have size "
3190 << c << ", got " << outputC;
3191 }
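// Each of the W index values selects a distinct write position among the K
// slots of values_in (repeated indices are not permitted), so W can never
// exceed K.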
3192 if (k != ShapedType::kDynamic && w != ShapedType::kDynamic && !(k >= w))
3193 return emitOpError() << "requires dimensions K >= W, got K=" << k
3194 << " and W=" << w;
3195
3196 return success();
3197}
3198
3199static LogicalResult ReduceInferReturnTypes(
3200 ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
3201 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3202 int64_t axisVal = axis.getValue().getSExtValue();
3203 if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
3204 inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
3205 return success();
3206 }
3207
3208 SmallVector<int64_t> outputShape;
3209 operandShape.getDims(outputShape);
3210 outputShape[axisVal] = 1;
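// For example (illustrative shapes): reducing a tensor<?x4x5xf32> along
// axis = 1 infers the shape ?x1x5.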
3211 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
3212 return success();
3213}
3214
3215#define COMPATIBLE_RETURN_TYPES(OP) \
3216 bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
3217 if (l.size() != r.size() || l.size() != 1) \
3218 return false; \
3219 if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
3220 return false; \
3221 return succeeded(verifyCompatibleShape(l[0], r[0])); \
3222 }
3223
3224#define REDUCE_SHAPE_INFER(OP) \
3225 LogicalResult OP::inferReturnTypeComponents( \
3226 MLIRContext *context, ::std::optional<Location> location, \
3227 OP::Adaptor adaptor, \
3228 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3229 Type inputType = \
3230 llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
3231 ShapeAdaptor inputShape(adaptor.getInput().getType()); \
3232 const Properties &prop = adaptor.getProperties(); \
3233 return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
3234 inferredReturnShapes); \
3235 } \
3236 COMPATIBLE_RETURN_TYPES(OP)
3237
3238REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
3239REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
3240REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
3241REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
3242REDUCE_SHAPE_INFER(tosa::ReduceProductOp)
3243REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
3244#undef REDUCE_SHAPE_INFER
3245COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
3246#undef COMPATIBLE_RETURN_TYPES
3247
3248template <typename T>
3249static LogicalResult verifyReduceOp(T op) {
3250 // All TOSA reduce Ops have input, output and axis.
3251 TensorType inputType = op.getInput().getType();
3252 TensorType outputType = op.getOutput().getType();
3253 int32_t reduceAxis = op.getAxis();
3254
3255 if (reduceAxis < 0) {
3256 op.emitOpError("reduce axis must not be negative");
3257 return failure();
3258 }
3259 if (inputType.hasRank()) {
3260 int64_t inputRank = inputType.getRank();
3261 // We allow for a special case where the input/output shape has rank 0 and
3262 // axis is also 0.
3263 if (reduceAxis >= inputRank && (reduceAxis != 0 || inputRank != 0)) {
3264 op.emitOpError("expect input tensor rank (")
3265 << inputRank << ") to be larger than reduce axis (" << reduceAxis
3266 << ")";
3267 return failure();
3268 }
3269 }
3270 if (outputType.hasRank()) {
3271 int64_t outputRank = outputType.getRank();
3272 if (inputType.hasRank() && outputRank != inputType.getRank()) {
3273 op.emitOpError(
3274 "expect output tensor rank to be equal to input tensor rank");
3275 return failure();
3276 }
3277 if (reduceAxis >= outputRank && (reduceAxis != 0 || outputRank != 0)) {
3278 op.emitOpError("expect output tensor rank (")
3279 << outputRank << ") to be larger than reduce axis (" << reduceAxis
3280 << ")";
3281 return failure();
3282 }
3283 // We can only verify the reduced dimension size to be 1 if this is not
3284 // the special case of output rank == 0.
3285 if (outputRank != 0) {
3286 auto outputShape = outputType.getShape();
3287 if (!outputType.isDynamicDim(reduceAxis) &&
3288 outputShape[reduceAxis] != 1) {
3289 op.emitOpError("expect reduced dimension size to be 1, got ")
3290 << outputShape[reduceAxis];
3291 return failure();
3292 }
3293 }
3294 }
3295 return success();
3296}
3297
3298LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
3299LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
3300LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
3301LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
3302LogicalResult tosa::ReduceProductOp::verify() { return verifyReduceOp(*this); }
3303LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
3304
3305static LogicalResult NAryInferReturnTypes(
3306 const ValueShapeRange &operands,
3307 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3308 llvm::SmallVector<int64_t> outShape;
3309 if (resolveBroadcastShape(operands, outShape).failed()) {
3310 inferredReturnShapes.push_back(ShapedTypeComponents());
3311 } else {
3312 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
3313 }
3314 return success();
3315}
3316
3317#define NARY_SHAPE_INFER(OP) \
3318 LogicalResult OP::inferReturnTypeComponents( \
3319 MLIRContext *context, ::std::optional<Location> location, \
3320 ValueShapeRange operands, DictionaryAttr attributes, \
3321 OpaqueProperties properties, RegionRange regions, \
3322 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3323 return NAryInferReturnTypes(operands, inferredReturnShapes); \
3324 }
3325
3326NARY_SHAPE_INFER(tosa::AbsOp)
3327NARY_SHAPE_INFER(tosa::AddOp)
3328NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
3329NARY_SHAPE_INFER(tosa::BitwiseAndOp)
3330NARY_SHAPE_INFER(tosa::BitwiseOrOp)
3331NARY_SHAPE_INFER(tosa::BitwiseXorOp)
3332NARY_SHAPE_INFER(tosa::BitwiseNotOp)
3333NARY_SHAPE_INFER(tosa::CastOp)
3334NARY_SHAPE_INFER(tosa::CeilOp)
3335NARY_SHAPE_INFER(tosa::ClampOp)
3336NARY_SHAPE_INFER(tosa::ClzOp)
3337NARY_SHAPE_INFER(tosa::CosOp)
3338NARY_SHAPE_INFER(tosa::ExpOp)
3339NARY_SHAPE_INFER(tosa::FloorOp)
3340NARY_SHAPE_INFER(tosa::GreaterEqualOp)
3341NARY_SHAPE_INFER(tosa::GreaterOp)
3342NARY_SHAPE_INFER(tosa::IdentityOp)
3343NARY_SHAPE_INFER(tosa::IntDivOp)
3344NARY_SHAPE_INFER(tosa::LogOp)
3345NARY_SHAPE_INFER(tosa::LogicalAndOp)
3346NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
3347NARY_SHAPE_INFER(tosa::LogicalNotOp)
3348NARY_SHAPE_INFER(tosa::LogicalOrOp)
3349NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
3350NARY_SHAPE_INFER(tosa::LogicalXorOp)
3351NARY_SHAPE_INFER(tosa::MaximumOp)
3352NARY_SHAPE_INFER(tosa::MinimumOp)
3353NARY_SHAPE_INFER(tosa::PowOp)
3354NARY_SHAPE_INFER(tosa::ReciprocalOp)
3355NARY_SHAPE_INFER(tosa::ReverseOp)
3356NARY_SHAPE_INFER(tosa::RsqrtOp)
3357NARY_SHAPE_INFER(tosa::SinOp)
3358NARY_SHAPE_INFER(tosa::SelectOp)
3359NARY_SHAPE_INFER(tosa::SubOp)
3360NARY_SHAPE_INFER(tosa::TanhOp)
3361NARY_SHAPE_INFER(tosa::ErfOp)
3362NARY_SHAPE_INFER(tosa::SigmoidOp)
3363 #undef NARY_SHAPE_INFER
3364
3365LogicalResult tosa::NegateOp::inferReturnTypeComponents(
3366 MLIRContext *context, ::std::optional<Location> location,
3367 NegateOp::Adaptor adaptor,
3368 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3369 ShapeAdaptor inputShape(adaptor.getInput1().getType());
3370 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3371 return success();
3372}
3373
3374LogicalResult tosa::NegateOp::verify() {
3375 // Verify same element type
3376 const Type input1Type = getInput1().getType();
3377 const Type outputType = getOutput().getType();
3378 if (verifySameElementTypes(*this, input1Type, outputType).failed())
3379 return failure();
3380
3381 // Verify same shape
3382 const SmallVector<Type, 2> types = {input1Type, outputType};
3383 if (failed(verifyCompatibleShapes(types)))
3384 return emitOpError() << "requires the same shape for input1 and output";
3385
3386 const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
3387 const Type input1ZpEType =
3388 getStorageElementTypeOrSelf(getInput1Zp().getType());
3389 if (input1EType != input1ZpEType) {
3390 return emitOpError("expect both input1 and its zero point are the same "
3391 "element type, got ")
3392 << input1EType << " and " << input1ZpEType;
3393 }
3394 const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
3395 const Type outputZpEType =
3396 getStorageElementTypeOrSelf(getOutputZp().getType());
3397 if (outputEType != outputZpEType) {
3398 return emitOpError("expect both output and its zero point are the same "
3399 "element type, got ")
3400 << outputEType << " and " << outputZpEType;
3401 }
3402
3403 FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
3404 if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
3405 return failure();
3406
3407 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3408 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
3409 return failure();
3410
3411 return success();
3412}
3413
3414static LogicalResult poolingInferReturnTypes(
3415 ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
3416 ArrayRef<int64_t> pad,
3417 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3418 llvm::SmallVector<int64_t> outputShape;
3419 outputShape.resize(4, ShapedType::kDynamic);
3420
3421 // If the input type is unranked, we only know that the output is rank-4 (NHWC).
3422 if (!inputShape.hasRank()) {
3423 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3424 return success();
3425 }
3426
3427 // Batch size and channel count pass through a pooling layer unchanged.
3428 outputShape[0] = inputShape.getDimSize(0);
3429 outputShape[3] = inputShape.getDimSize(3);
3430
3431 int64_t height = inputShape.getDimSize(1);
3432 int64_t width = inputShape.getDimSize(2);
3433
3434 if (ShapedType::isStatic(height)) {
3435 int64_t padded = height + pad[0] + pad[1] - kernel[0];
3436 outputShape[1] = padded / stride[0] + 1;
3437 }
3438
3439 if (ShapedType::isStatic(width)) {
3440 int64_t padded = width + pad[2] + pad[3] - kernel[1];
3441 outputShape[2] = padded / stride[1] + 1;
3442 }
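// As an illustration (hypothetical values): height 112 with pad_top = 1,
// pad_bottom = 1, kernel 3 and stride 2 gives (112 + 1 + 1 - 3) / 2 + 1 = 56.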
3443
3444 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3445 return success();
3446}
3447
3448LogicalResult Conv2DOp::inferReturnTypeComponents(
3449 MLIRContext *context, ::std::optional<Location> location,
3450 Conv2DOp::Adaptor adaptor,
3451 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3452 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3453
3454 int64_t inputWidth = ShapedType::kDynamic;
3455 int64_t inputHeight = ShapedType::kDynamic;
3456 int64_t weightWidth = ShapedType::kDynamic;
3457 int64_t weightHeight = ShapedType::kDynamic;
3458
3459 // Input shape describes input width/height and batch.
3460
3461 ShapeAdaptor inputShape(adaptor.getInput().getType());
3462 if (inputShape.hasRank()) {
3463 outputShape[0] = inputShape.getDimSize(0);
3464 inputHeight = inputShape.getDimSize(1);
3465 inputWidth = inputShape.getDimSize(2);
3466 }
3467
3468 // The weight shape describes the filter height/width and the output channels.
3469 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3470 if (weightShape.hasRank()) {
3471 outputShape[3] = weightShape.getDimSize(0);
3472 weightHeight = weightShape.getDimSize(1);
3473 weightWidth = weightShape.getDimSize(2);
3474 }
3475
3476 // Bias shape can describe the output channels.
3477 ShapeAdaptor biasShape(adaptor.getBias().getType());
3478 if (biasShape.hasRank()) {
3479 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3480 ? biasShape.getDimSize(0)
3481 : outputShape[3];
3482 }
3483
3484 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3485 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3486 llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3487
3488 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3489 int64_t inputSize = inputHeight + padding[0] + padding[1];
3490 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3491 int64_t unstridedResult = inputSize - filterSize + 1;
3492 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3493 }
3494
3495 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3496 int64_t inputSize = inputWidth + padding[2] + padding[3];
3497 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3498 int64_t unstridedResult = inputSize - filterSize + 1;
3499 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3500 }
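// Illustrative example (hypothetical values): input width 224 with pad 3 + 3,
// kernel width 7, dilation 1 and stride 2 gives inputSize = 230,
// filterSize = 7, unstridedResult = 224 and (224 - 1) / 2 + 1 = 112.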
3501
3502 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3503 return success();
3504}
3505
3506LogicalResult Conv2DOp::verify() {
3507 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3508 verifyConvOpErrorIf(*this).failed())
3509 return failure();
3510 return success();
3511}
3512
3513LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
3514 MLIRContext *context, ::std::optional<Location> location,
3515 Conv2DBlockScaledOp::Adaptor adaptor,
3516 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3517 SmallVector<int64_t, 4> outShape(4, ShapedType::kDynamic);
3518
3519 int64_t inputWidth = ShapedType::kDynamic;
3520 int64_t inputHeight = ShapedType::kDynamic;
3521 int64_t weightWidth = ShapedType::kDynamic;
3522 int64_t weightHeight = ShapedType::kDynamic;
3523
3524 // Input shape describes input width/height and batch.
3525 const ShapeAdaptor inputDataShape(adaptor.getInputData().getType());
3526 if (inputDataShape.hasRank()) {
3527 outShape[0] = inputDataShape.getDimSize(0);
3528 inputHeight = inputDataShape.getDimSize(1);
3529 inputWidth = inputDataShape.getDimSize(2);
3530 }
3531 const ShapeAdaptor inputScaleShape(adaptor.getInputScale().getType());
3532 if (inputScaleShape.hasRank()) {
3533 outShape[0] = ShapedType::isDynamic(outShape[0])
3534 ? inputScaleShape.getDimSize(0)
3535 : outShape[0];
3536 inputHeight = ShapedType::isDynamic(inputHeight)
3537 ? inputScaleShape.getDimSize(1)
3538 : inputHeight;
3539 inputWidth = ShapedType::isDynamic(inputWidth)
3540 ? inputScaleShape.getDimSize(2)
3541 : inputWidth;
3542 }
3543
3544 // The weight shape describes the filter height/width and the output channels.
3545 const ShapeAdaptor weightDataShape(adaptor.getWeightData().getType());
3546 if (weightDataShape.hasRank()) {
3547 outShape[3] = weightDataShape.getDimSize(0);
3548 weightHeight = weightDataShape.getDimSize(1);
3549 weightWidth = weightDataShape.getDimSize(2);
3550 }
3551 const ShapeAdaptor weightScaleShape(adaptor.getWeightScale().getType());
3552 if (weightScaleShape.hasRank()) {
3553 outShape[3] = ShapedType::isDynamic(outShape[3])
3554 ? weightScaleShape.getDimSize(0)
3555 : outShape[3];
3556 weightHeight = ShapedType::isDynamic(weightHeight)
3557 ? weightScaleShape.getDimSize(1)
3558 : weightHeight;
3559 weightWidth = ShapedType::isDynamic(weightWidth)
3560 ? weightScaleShape.getDimSize(2)
3561 : weightWidth;
3562 }
3563
3564 // Bias shape can describe the output channels.
3565 const ShapeAdaptor biasShape(adaptor.getBias().getType());
3566 if (biasShape.hasRank()) {
3567 const int64_t biasSize = biasShape.getDimSize(0);
3568 // Bias of size 1 may be broadcast
3569 if (biasSize != 1) {
3570 outShape[3] = ShapedType::isDynamic(outShape[3]) ? biasSize : outShape[3];
3571 }
3572 }
3573
3574 SmallVector<int64_t> padValues;
3575 SmallVector<int64_t> strideValues;
3576 SmallVector<int64_t> dilationValues;
3577 if (!tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(), padValues) ||
3578 !tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(),
3579 strideValues) ||
3580 !tosa::getConstShapeValues(adaptor.getDilation().getDefiningOp(),
3581 dilationValues)) {
3582 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
3583 return success();
3584 }
3585
3586 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3587 const int64_t inputSize = inputHeight + padValues[0] + padValues[1];
3588 const int64_t filterSize = (weightHeight - 1) * dilationValues[0] + 1;
3589 const int64_t unstridedResult = inputSize - filterSize + 1;
3590 outShape[1] = (unstridedResult - 1) / strideValues[0] + 1;
3591 }
3592
3593 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3594 const int64_t inputSize = inputWidth + padValues[2] + padValues[3];
3595 const int64_t filterSize = (weightWidth - 1) * dilationValues[1] + 1;
3596 const int64_t unstridedResult = inputSize - filterSize + 1;
3597 outShape[2] = (unstridedResult - 1) / strideValues[1] + 1;
3598 }
3599
3600 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
3601 return success();
3602}
3603
3604LogicalResult Conv2DBlockScaledOp::verify() {
3605 if (failed(verifySameElementTypes(*this, getInputData().getType(),
3606 getWeightData().getType(), "input_data",
3607 "weight_data")) ||
3608 failed(verifySameElementTypes(*this, getInputScale().getType(),
3609 getWeightScale().getType(), "input_scale",
3610 "weight_scale")) ||
3611 failed(verifySameElementTypes(*this, getBias().getType(),
3612 getOutput().getType(), "bias", "output")))
3613 return failure();
3614
3615 // Verify input shape compatibility
3616 int64_t N = ShapedType::kDynamic;
3617 int64_t IH = ShapedType::kDynamic;
3618 int64_t IW = ShapedType::kDynamic;
3619 int64_t IC = ShapedType::kDynamic;
3620 int64_t multiplesOfIC = ShapedType::kDynamic;
3621 int64_t OC = ShapedType::kDynamic;
3622 int64_t KH = ShapedType::kDynamic;
3623 int64_t KW = ShapedType::kDynamic;
3624
3625 const ShapeAdaptor inputDataShape(getInputData().getType());
3626 if (inputDataShape.hasRank()) {
3627 N = inputDataShape.getDimSize(0);
3628 IH = inputDataShape.getDimSize(1);
3629 IW = inputDataShape.getDimSize(2);
3630 IC = inputDataShape.getDimSize(3);
3631 }
3632
3633 const ShapeAdaptor inputScaleShape(getInputScale().getType());
3634 if (inputScaleShape.hasRank()) {
3635 if (failed(tryUpdateDimOrFailure(*this, N, inputScaleShape.getDimSize(0),
3636 "input_scale", "batch size")) ||
3637 failed(tryUpdateDimOrFailure(*this, IH, inputScaleShape.getDimSize(1),
3638 "input_scale", "input height")) ||
3639 failed(tryUpdateDimOrFailure(*this, IW, inputScaleShape.getDimSize(2),
3640 "input_scale", "input width")))
3641 return failure();
3642 multiplesOfIC = inputScaleShape.getDimSize(3);
3643 }
3644
3645 const ShapeAdaptor weightDataShape(getWeightData().getType());
3646 if (weightDataShape.hasRank()) {
3647 OC = weightDataShape.getDimSize(0);
3648 KH = weightDataShape.getDimSize(1);
3649 KW = weightDataShape.getDimSize(2);
3650 if (failed(tryUpdateDimOrFailure(*this, IC, weightDataShape.getDimSize(3),
3651 "weight_data", "input channels")))
3652 return failure();
3653 }
3654
3655 const ShapeAdaptor weightScaleShape(getWeightScale().getType());
3656 if (weightScaleShape.hasRank()) {
3657 if (failed(tryUpdateDimOrFailure(*this, OC, weightScaleShape.getDimSize(0),
3658 "weight_scale", "output channels")) ||
3659 failed(tryUpdateDimOrFailure(*this, KH, weightScaleShape.getDimSize(1),
3660 "weight_scale", "kernel height")) ||
3661 failed(tryUpdateDimOrFailure(*this, KW, weightScaleShape.getDimSize(2),
3662 "weight_scale", "kernel width")) ||
3663 failed(tryUpdateDimOrFailure(*this, multiplesOfIC,
3664 weightScaleShape.getDimSize(3),
3665 "weight_scale", "input channel blocks")))
3666 return failure();
3667 }
3668
3669 // Verify IC is a multiple of block size
3670 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
3671 if (ShapedType::isStatic(IC) && IC % blockSize != 0)
3672 return emitOpError("expect IC to be a multiple of block size, got IC=")
3673 << IC << ", block_size=" << blockSize;
3674
3675 // Verify multiplesOfIC is IC / block size
3676 if (ShapedType::isStatic(IC) && ShapedType::isStatic(multiplesOfIC) &&
3677 multiplesOfIC != IC / blockSize)
3678 return emitOpError(
3679 "expect scale operands dimension 2 to equal IC/block_size (")
3680 << IC << "/" << blockSize << ")"
3681 << ", got " << multiplesOfIC;
3682
3683 // Verify pad/stride/dilation values
3684 SmallVector<int64_t> padValues;
3685 if (tosa::getConstShapeValues(getPad().getDefiningOp(), padValues)) {
3686 if (llvm::any_of(padValues, [](int64_t p) { return p < 0; }))
3687 return emitOpError("expect all padding values to be >= 0, got ")
3688 << padValues;
3689 }
3690
3691 SmallVector<int64_t> strideValues;
3692 if (tosa::getConstShapeValues(getStride().getDefiningOp(), strideValues)) {
3693 if (llvm::any_of(strideValues, [](int64_t s) { return s < 1; }))
3694 return emitOpError("expect all stride values to be >= 1, got ")
3695 << strideValues;
3696 }
3697
3698 SmallVector<int64_t> dilationValues;
3699 if (tosa::getConstShapeValues(getDilation().getDefiningOp(),
3700 dilationValues)) {
3701 if (llvm::any_of(dilationValues, [](int64_t d) { return d < 1; }))
3702 return emitOpError("expect all dilation values to be >= 1, got ")
3703 << dilationValues;
3704 }
3705
3706 // Verify output shape compatibility
3707 const ShapeAdaptor outputShape(getOutput().getType());
3708 if (!padValues.empty() && !strideValues.empty() && !dilationValues.empty() &&
3709 outputShape.hasRank()) {
3710 if (failed(verifyConvOutputSize(*this, IH, KH, outputShape.getDimSize(1),
3711 padValues[0], padValues[1], strideValues[0],
3712 dilationValues[0], "height", "y", "top",
3713 "bottom")) ||
3714 failed(verifyConvOutputSize(*this, IW, KW, outputShape.getDimSize(2),
3715 padValues[2], padValues[3], strideValues[1],
3716 dilationValues[1], "width", "x", "left",
3717 "right")))
3718 return failure();
3719 }
3720
3721 // Verify bias
3722 const ShapeAdaptor biasShape(getBias().getType());
3723 if (biasShape.hasRank() && outputShape.hasRank()) {
3724 const int64_t biasChannels = biasShape.getDimSize(0);
3725 const int64_t outputChannels =
3726 outputShape.getDimSize(outputShape.getRank() - 1);
3727 if (biasChannels == ShapedType::kDynamic ||
3728 outputChannels == ShapedType::kDynamic)
3729 // Skip the following checks if biasChannels or outputChannels is a dynamic dim
3730 return success();
3731
3732 if (biasChannels != outputChannels && biasChannels != 1)
3733 return emitOpError(
3734 "bias channels expected to be equal to output channels (")
3735 << outputChannels << ") or 1, got " << biasChannels;
3736 }
3737
3738 return success();
3739}
3740
3741LogicalResult Conv3DOp::inferReturnTypeComponents(
3742 MLIRContext *context, ::std::optional<Location> location,
3743 Conv3DOp::Adaptor adaptor,
3744 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3745 llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
3746
3747 int64_t inputWidth = ShapedType::kDynamic;
3748 int64_t inputHeight = ShapedType::kDynamic;
3749 int64_t inputDepth = ShapedType::kDynamic;
3750
3751 int64_t weightWidth = ShapedType::kDynamic;
3752 int64_t weightHeight = ShapedType::kDynamic;
3753 int64_t weightDepth = ShapedType::kDynamic;
3754
3755 // Input shape describes input width/height and batch.
3756 ShapeAdaptor inputShape(adaptor.getInput().getType());
3757 if (inputShape.hasRank()) {
3758 outputShape[0] = inputShape.getDimSize(0);
3759 inputDepth = inputShape.getDimSize(1);
3760 inputHeight = inputShape.getDimSize(2);
3761 inputWidth = inputShape.getDimSize(3);
3762 }
3763
3764 // The weight shape describes the filter depth/height/width and the output channels.
3765 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3766 if (weightShape.hasRank()) {
3767 outputShape[4] = weightShape.getDimSize(0);
3768 weightDepth = weightShape.getDimSize(1);
3769 weightHeight = weightShape.getDimSize(2);
3770 weightWidth = weightShape.getDimSize(3);
3771 }
3772
3773 // Bias shape can describe the output channels.
3774 ShapeAdaptor biasShape(adaptor.getBias().getType());
3775 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
3776 outputShape[4] = biasShape.getDimSize(0);
3777 }
3778
3779 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3780 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3781 llvm::ArrayRef<int64_t> pad = adaptor.getPad();
3782
3783 if (ShapedType::isStatic(inputDepth) && ShapedType::isStatic(weightDepth)) {
3784 int64_t inputSize = inputDepth + pad[0] + pad[1];
3785 int64_t filterSize = (weightDepth - 1) * dilation[0] + 1;
3786 int64_t unstridedResult = inputSize - filterSize + 1;
3787 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3788 }
3789
3790 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3791 int64_t inputSize = inputHeight + pad[2] + pad[3];
3792 int64_t filterSize = (weightHeight - 1) * dilation[1] + 1;
3793 int64_t unstridedResult = inputSize - filterSize + 1;
3794 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3795 }
3796
3797 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3798 int64_t inputSize = inputWidth + pad[4] + pad[5];
3799 int64_t filterSize = (weightWidth - 1) * dilation[2] + 1;
3800 int64_t unstridedResult = inputSize - filterSize + 1;
3801 outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
3802 }
3803
3804 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3805 return success();
3806}
3807
3808LogicalResult Conv3DOp::verify() {
3809 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3810 verifyConvOpErrorIf(*this).failed())
3811 return failure();
3812 return success();
3813}
3814
3815LogicalResult AvgPool2dOp::inferReturnTypeComponents(
3816 MLIRContext *context, ::std::optional<Location> location,
3817 AvgPool2dOp::Adaptor adaptor,
3818 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3819 ShapeAdaptor inputShape(adaptor.getInput().getType());
3820 const Properties &prop = adaptor.getProperties();
3821 return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3822 inferredReturnShapes);
3823}
3824
3825LogicalResult MaxPool2dOp::inferReturnTypeComponents(
3826 MLIRContext *context, ::std::optional<Location> location,
3827 MaxPool2dOp::Adaptor adaptor,
3828 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3829 ShapeAdaptor inputShape(adaptor.getInput().getType());
3830 const Properties &prop = adaptor.getProperties();
3831 return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3832 inferredReturnShapes);
3833}
3834
3835LogicalResult MaxPool2dOp::verify() {
3836 if (failed(verifySameElementTypes(*this, /* inType = */ getInput().getType(),
3837 /* outType = */ getOutput().getType())))
3838 return failure();
3839
3840 if (failed(verifyPoolingOp(*this)))
3841 return failure();
3842
3843 return success();
3844}
3845
3846LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
3847 MLIRContext *context, ::std::optional<Location> location,
3848 DepthwiseConv2DOp::Adaptor adaptor,
3849 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3850 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3851
3852 int64_t inputWidth = ShapedType::kDynamic;
3853 int64_t inputHeight = ShapedType::kDynamic;
3854 int64_t inputChannels = ShapedType::kDynamic;
3855
3856 int64_t weightWidth = ShapedType::kDynamic;
3857 int64_t weightHeight = ShapedType::kDynamic;
3858 int64_t depthChannels = ShapedType::kDynamic;
3859
3860 // Input shape describes input width/height and batch.
3861 ShapeAdaptor inputShape(adaptor.getInput().getType());
3862 if (inputShape.hasRank()) {
3863 outputShape[0] = inputShape.getDimSize(0);
3864 inputHeight = inputShape.getDimSize(1);
3865 inputWidth = inputShape.getDimSize(2);
3866 inputChannels = inputShape.getDimSize(3);
3867 }
3868
3869 // The weight shape describes the filter height/width, the input channels, and the channel multiplier.
3870 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3871 if (weightShape.hasRank()) {
3872 weightHeight = weightShape.getDimSize(0);
3873 weightWidth = weightShape.getDimSize(1);
3874 inputChannels = ShapedType::isDynamic(inputChannels)
3875 ? weightShape.getDimSize(2)
3876 : inputChannels;
3877 depthChannels = weightShape.getDimSize(3);
3878 }
3879
3880 // If both inputChannels and depthChannels are available we can determine
3881 // the output channels.
3882 if (ShapedType::isStatic(inputChannels) &&
3883 ShapedType::isStatic(depthChannels)) {
3884 outputShape[3] = inputChannels * depthChannels;
3885 }
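// For example (hypothetical values): 8 input channels with a channel
// multiplier of 3 gives 8 * 3 = 24 output channels.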
3886
3887 // Bias shape can describe the output channels.
3888 ShapeAdaptor biasShape(adaptor.getBias().getType());
3889 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
3890 int64_t bc = biasShape.getDimSize(0);
3891 if (bc != ShapedType::kDynamic && bc != 1)
3892 outputShape[3] = bc;
3893 }
3894
3895 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3896 llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3897 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3898
3899 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3900 int64_t inputSize = inputHeight + padding[0] + padding[1];
3901 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3902 int64_t unstridedResult = inputSize - filterSize + 1;
3903 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3904 }
3905
3906 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3907 int64_t inputSize = inputWidth + padding[2] + padding[3];
3908 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3909 int64_t unstridedResult = inputSize - filterSize + 1;
3910 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3911 }
3912
3913 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3914 return success();
3915}
3916
3917LogicalResult DepthwiseConv2DOp::verify() {
3918 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3919 verifyConvOpErrorIf(*this).failed())
3920 return failure();
3921 return success();
3922}
3923
3924LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
3925 MLIRContext *context, ::std::optional<Location> location,
3926 TransposeConv2DOp::Adaptor adaptor,
3927 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3928 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3929
3930 int64_t inputWidth = ShapedType::kDynamic;
3931 int64_t inputHeight = ShapedType::kDynamic;
3932 int64_t weightWidth = ShapedType::kDynamic;
3933 int64_t weightHeight = ShapedType::kDynamic;
3934
3935 // Input shape describes input width/height and batch.
3936 ShapeAdaptor inputShape(adaptor.getInput().getType());
3937 if (inputShape.hasRank()) {
3938 outputShape[0] = ShapedType::isDynamic(outputShape[0])
3939 ? inputShape.getDimSize(0)
3940 : outputShape[0];
3941 inputHeight = inputShape.getDimSize(1);
3942 inputWidth = inputShape.getDimSize(2);
3943 }
3944
3945 // The weight shape describes the filter height/width and the output channels.
3946 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3947 if (weightShape.hasRank()) {
3948 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3949 ? weightShape.getDimSize(0)
3950 : outputShape[3];
3951 weightHeight = weightShape.getDimSize(1);
3952 weightWidth = weightShape.getDimSize(2);
3953 }
3954
3955 // Bias shape can describe the output channels.
3956 ShapeAdaptor biasShape(adaptor.getBias().getType());
3957 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
3958 int64_t bc = biasShape.getDimSize(0);
3959 if (bc != ShapedType::kDynamic && bc != 1)
3960 outputShape[3] = bc;
3961 }
3962
3963 llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
3964 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3965
3966 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3967 int64_t calculateSize =
3968 (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
3969 outputShape[1] =
3970 ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
3971 }
3972
3973 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3974 int64_t calculateSize =
3975 (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
3976 outputShape[2] =
3977 ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
3978 }
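// Illustrative example (hypothetical values): input width 56, stride 2,
// out_pad 0 + 0 and kernel width 2 gives (56 - 1) * 2 + 0 + 0 + 2 = 112.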
3979
3980 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3981 return success();
3982}
3983
3984LogicalResult TransposeConv2DOp::verify() {
3985 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
3986 return failure();
3987
3988 const llvm::ArrayRef<int64_t> strides = getStride();
3989 const int64_t strideY = strides[0];
3990 const int64_t strideX = strides[1];
3991
3992 if (strideY < 1 || strideX < 1)
3993 return emitOpError("expect all stride values to be >= 1, got [")
3994 << strides << "]";
3995
3996 const auto checkPadAgainstKernelDim =
3997 [this](int64_t padValue, int64_t kernelDimSize, llvm::StringRef padName,
3998 llvm::StringRef kernelDimName) -> LogicalResult {
3999 if (padValue <= -kernelDimSize)
4000 return emitOpError("expected ")
4001 << padName << " > -" << kernelDimName << ", but got: " << padName
4002 << "=" << padValue << " and " << kernelDimName << "="
4003 << kernelDimSize;
4004 return success();
4005 };
4006
4007 const llvm::ArrayRef<int64_t> padding = getOutPad();
4008 const int64_t outPadTop = padding[0];
4009 const int64_t outPadBottom = padding[1];
4010 const int64_t outPadLeft = padding[2];
4011 const int64_t outPadRight = padding[3];
4012
4013 const auto weightType =
4014 llvm::dyn_cast<RankedTensorType>(getWeight().getType());
4015
4016 if (weightType) {
4017 const int64_t kernelHeight = weightType.getDimSize(1);
4018 if (ShapedType::isStatic(kernelHeight)) {
4019 if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
4020 "out_pad_top", "KH")))
4021 return failure();
4022
4023 if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
4024 "out_pad_bottom", "KH")))
4025 return failure();
4026 }
4027
4028 const int64_t kernelWidth = weightType.getDimSize(2);
4029 if (ShapedType::isStatic(kernelWidth)) {
4030 if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
4031 "out_pad_left", "KW")))
4032 return failure();
4033
4034 if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
4035 "out_pad_right", "KW")))
4036 return failure();
4037 }
4038 }
4039
4040 // Rest of the checks depend on the output type being a RankedTensorType
4041 const auto outputType =
4042 llvm::dyn_cast<RankedTensorType>(getOutput().getType());
4043 if (!outputType)
4044 return success();
4045
4046 const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
4047 if (inputType && weightType) {
4048 const int64_t inputHeight = inputType.getDimSize(1);
4049 const int64_t kernelHeight = weightType.getDimSize(1);
4050 const int64_t outputHeight = outputType.getDimSize(1);
4051
4052 if (ShapedType::isStatic(inputHeight) &&
4053 ShapedType::isStatic(outputHeight)) {
4054 if (outputHeight !=
4055 (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
4056 return emitOpError(
4057 "dimension mismatch: expected OH == (IH - 1) * stride_y "
4058 "+ out_pad_top + out_pad_bottom + KH, but got ")
4059 << outputHeight << " != (" << inputHeight << " - 1) * "
4060 << strideY << " + " << outPadTop << " + " << outPadBottom
4061 << " + " << kernelHeight;
4062 }
4063
4064 const int64_t inputWidth = inputType.getDimSize(2);
4065 const int64_t kernelWidth = weightType.getDimSize(2);
4066 const int64_t outputWidth = outputType.getDimSize(2);
4067
4068 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
4069 if (outputWidth !=
4070 (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
4071 return emitOpError(
4072 "dimension mismatch: expected OW == (IW - 1) * stride_x "
4073 "+ out_pad_left + out_pad_right + KW, but got ")
4074 << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
4075 << " + " << outPadLeft << " + " << outPadRight << " + "
4076 << kernelWidth;
4077 }
4078 }
4079
4080 const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
4081
4082 if (!biasType)
4083 return success();
4084
4085 const int64_t biasChannels = biasType.getDimSize(0);
4086
4087 // Skip further checks if bias is dynamic
4088 if (biasChannels == ShapedType::kDynamic)
4089 return success();
4090
4091 const int64_t outputChannels = outputType.getDimSize(3);
4092 if (!ShapedType::isDynamic(outputChannels) &&
4093 biasChannels != outputChannels && biasChannels != 1)
4094 return emitOpError(
4095 "bias channels expected to be equal to output channels (")
4096 << outputChannels << ") or 1, got " << biasChannels;
4097
4098 return success();
4099}
4100
4101LogicalResult RescaleOp::verify() {
4102 auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
4103 if (!inputType) {
4104 emitOpError("expect shaped tensor for input, got ") << getInput().getType();
4105 return failure();
4106 }
4107
4108 auto inputElementType =
4109 getStorageElementTypeOrSelf(inputType.getElementType());
4110 if (!mlir::isa<IntegerType>(inputElementType)) {
4111 emitOpError("expect input to have integer element type, got ")
4112 << inputElementType;
4113 return failure();
4114 }
4115
4116 auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
4117 if (!outputType) {
4118 emitOpError("expect shaped tensor for output, got ")
4119 << getOutput().getType();
4120 return failure();
4121 }
4122
4123 auto outputElementType =
4124 getStorageElementTypeOrSelf(outputType.getElementType());
4125 if (!mlir::isa<IntegerType>(outputElementType)) {
4126 emitOpError("expect output to have integer element type, got ")
4127 << outputElementType;
4128 return failure();
4129 }
4130
4131 if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
4132 .failed())
4133 return failure();
4134
4135 if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
4136 .failed())
4137 return failure();
4138
4139 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
4140 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
4141 return failure();
4142
4143 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
4144 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
4145 return failure();
4146
4147 auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
4148 if (!multiplierType) {
4149 emitOpError("expect shaped tensor for multiplier, got ")
4150 << getMultiplier().getType();
4151 return failure();
4152 }
4153
4154 auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
4155 if (!shiftType) {
4156 emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
4157 return failure();
4158 }
4159
4160 // multiplier element type must be i32 for scale32 = true
4161 if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
4162 emitOpError("expect i32 element type for multiplier for scale32=true, got ")
4163 << multiplierType.getElementType();
4164 return failure();
4165 }
4166
4167 // multiplier element type must be i16 for scale32 = false
4168 if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
4170 "expect i16 element type for multiplier for scale32=false, got ")
4171 << multiplierType.getElementType();
4172 return failure();
4173 }
4174
4175 if (!inputType.hasRank())
4176 return success();
4177
4178 // multiplier/shift must have shape = {numChannels},
4179 // where numChannels is 1 if per_channel = false,
4180 // otherwise numChannels is the size of the input shape's last dimension
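// For example (hypothetical shapes): with per_channel = true and an input of
// type tensor<1x5x8xi8>, multiplier and shift must both have shape {8}; with
// per_channel = false they must have shape {1}.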
4181 int64_t numChannels = 1;
4182 if (getPerChannel()) {
4183 if (inputType.getRank() < 1) {
4184 emitOpError("requires input to be at least rank 1 when per_channel is "
4185 "true, but got rank ")
4186 << inputType.getRank();
4187 return failure();
4188 }
4189 numChannels = inputType.getDimSize(inputType.getRank() - 1);
4190 }
4191
4192 if (!multiplierType.hasRank())
4193 return success();
4194
4195 ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
4196 // multiplier input has rank 1 by dialect definition
4197 if (multiplierShape[0] != ShapedType::kDynamic &&
4198 multiplierShape[0] != numChannels) {
4199 emitOpError("expect shape of { ")
4200 << numChannels << " } for multiplier input, got { "
4201 << multiplierShape[0] << " }";
4202 return failure();
4203 }
4204
4205 if (!shiftType.hasRank())
4206 return success();
4207
4208 ArrayRef<int64_t> shiftShape = shiftType.getShape();
4209 // shift input has rank 1 by dialect definition
4210 if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
4211 emitOpError("expect shape of { ")
4212 << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
4213 return failure();
4214 }
4215
4216 return success();
4217}
4218
4219LogicalResult RescaleOp::inferReturnTypeComponents(
4220 MLIRContext *context, ::std::optional<Location> location,
4221 RescaleOp::Adaptor adaptor,
4222 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4223 ShapeAdaptor inputShape(adaptor.getInput().getType());
4224 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4225 return success();
4226}
4227
4228LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
4229 MLIRContext *context, ::std::optional<Location> location,
4230 CastFromBlockScaledOp::Adaptor adaptor,
4231 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4232 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
4233 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4234 return success();
4235}
4236
4237LogicalResult CastFromBlockScaledOp::verify() {
4238 const Type inputDataType = getInputData().getType();
4239 const Type outputDataType = getResult().getType();
4240 if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
4241 return emitOpError() << "require compatible shapes for input_data ("
4242 << inputDataType << ") and " << "output_data ("
4243 << outputDataType << ")";
4244
4245 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
4246
4247 if (inputDataShape.hasRank()) {
4248 const unsigned int blockSize =
4249 BlockSizeAttr::getBlockSizeValue(getBlockSize());
4250 const int64_t inputDataLastDim =
4251 inputDataShape.getDimSize(inputDataShape.getRank() - 1);
4252 if (inputDataLastDim % blockSize != 0)
4253 return emitOpError() << "expect last dimension of input_data ("
4254 << inputDataLastDim
4255 << ") to be divisible by block_size (" << blockSize
4256 << ")";
4257
4258 const Type inputScaleType = getInputScale().getType();
4259 const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);
4260
4261 if (inputScaleShape.hasRank()) {
4262 SmallVector<int64_t> inputDataDims, inputScaleDims;
4263 inputDataShape.getDims(inputDataDims);
4264 inputScaleShape.getDims(inputScaleDims);
4265
4266 if (inputDataDims.size() != inputScaleDims.size() ||
4267 failed(verifyCompatibleShape(
4268 ArrayRef<int64_t>(inputDataDims).drop_back(1),
4269 ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
4270 return emitOpError()
4271 << "require compatible shapes for input_data (" << inputDataType
4272 << ") and " << "input_scale (" << inputScaleType
4273 << ") except for the last dimension";
4274
4275 const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
4276 inputScaleDims.back()};
4277 if (ShapedType::isStatic(inputDataLastDim) &&
4278 failed(verifyCompatibleDims(dimsToCheck)))
4279 return emitOpError()
4280 << "expect last dimension of input_scale ("
4281 << inputScaleDims.back()
4282 << ") to be equal to last dimension of input_data / block_size ("
4283 << inputDataDims.back() / blockSize << ")";
4284 }
4285 }
4286
4287 return success();
4288}
4289
4290LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
4291 MLIRContext *context, ::std::optional<Location> location,
4292 CastToBlockScaledOp::Adaptor adaptor,
4293 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4294 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
4295 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4296 if (!inputShape.hasRank())
4297 return success();
4298
4299 // Calculate the output_scale shape if a ranked input is provided
4300 SmallVector<int64_t> outputScaleShape;
4301 inputShape.getDims(outputScaleShape);
4302 const int64_t lastDimLoc = inputShape.getRank() - 1;
4303 const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
4304 if (ShapedType::isStatic(lastDimSize)) {
4305 const unsigned int blockSize =
4306 BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
4307 outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
4308 }
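// For example (hypothetical values): an input whose last dimension is 64 with
// block_size = 32 yields an output_scale whose last dimension is 64 / 32 = 2.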
4309 inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
4310 return success();
4311}
4312
4313LogicalResult CastToBlockScaledOp::verify() {
4314 const Type inputDataType = getInputData().getType();
4315 const Type outputDataType = getResult(0).getType();
4316 if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
4317 return emitOpError() << "require compatible shapes for input_data ("
4318 << inputDataType << ") and " << "output_data ("
4319 << outputDataType << ")";
4320
4321 const unsigned int blockSize =
4322 BlockSizeAttr::getBlockSizeValue(getBlockSize());
4323 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
4324 if (inputDataShape.hasRank()) {
4325 const int64_t inputDataLastDim =
4326 inputDataShape.getDimSize(inputDataShape.getRank() - 1);
4327 if (ShapedType::isStatic(inputDataLastDim) &&
4328 inputDataLastDim % blockSize != 0)
4329 return emitOpError() << "expect last dimension of input_data ("
4330 << inputDataLastDim
4331 << ") to be divisible by block_size (" << blockSize
4332 << ")";
4333 }
4334
4335 const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
4336 const Type outputScaleType = getResult(1).getType();
4337 const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
4338 if (outputDataShape.hasRank() && outputScaleShape.hasRank()) {
4339 SmallVector<int64_t> outputDataDims, outputScaleDims;
4340 outputDataShape.getDims(outputDataDims);
4341 outputScaleShape.getDims(outputScaleDims);
4342
4343 if (outputDataDims.size() != outputScaleDims.size() ||
4344 failed(verifyCompatibleShape(
4345 ArrayRef<int64_t>(outputDataDims).drop_back(1),
4346 ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
4347 return emitOpError() << "require compatible shapes for output_data ("
4348 << outputDataType << ") and " << "output_scale ("
4349 << outputScaleType
4350 << ") except for the last dimension";
4351
4352 const int64_t outputDataLastDim = outputDataDims.back();
4353 const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
4354 outputScaleDims.back()};
4355 if (ShapedType::isStatic(outputDataLastDim) &&
4356 failed(verifyCompatibleDims(dimsToCheck)))
4357 return emitOpError()
4358 << "expect last dimension of output_scale ("
4359 << outputScaleDims.back()
4360 << ") to be equal to last dimension of output_data / block_size ("
4361 << outputDataDims.back() / blockSize << ")";
4362 }
4363
4364 return success();
4365}
4366
4367LogicalResult IfOp::inferReturnTypeComponents(
4368 MLIRContext *context, ::std::optional<Location> location,
4369 IfOp::Adaptor adaptor,
4370 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4371 llvm::SmallVector<tosa::YieldOp> yieldOps;
4372 for (Region *region : adaptor.getRegions()) {
4373 for (auto &block : *region)
4374 if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
4375 yieldOps.push_back(returnOp);
4376 }
4377
4378 if (yieldOps.empty())
4379 return failure();
4380
4381 // Get the initial type information for the yield op.
4382 llvm::SmallVector<ValueKnowledge> resultKnowledge;
4383 resultKnowledge.reserve(yieldOps.front().getNumOperands());
4384 for (auto operand : yieldOps.front().getOperands()) {
4385 resultKnowledge.push_back(
4386 ValueKnowledge::getKnowledgeFromType(operand.getType()));
4387 }
4388
4389 for (auto yieldOp : yieldOps) {
4390 if (resultKnowledge.size() != yieldOp.getNumOperands())
4391 return failure();
4392
4393 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
4394 int32_t index = it.index();
4395 auto meet = ValueKnowledge::meet(
4396 resultKnowledge[index],
4397 ValueKnowledge::getKnowledgeFromType(it.value().getType()));
4398 if (!meet)
4399 continue;
4400 resultKnowledge[index] = meet;
4401 }
4402 }
4403
4404 for (const ValueKnowledge &result : resultKnowledge) {
4405 inferredReturnShapes.push_back(result.getShapedTypeComponents());
4406 }
4407
4408 return success();
4409}
4410
4411LogicalResult WhileOp::inferReturnTypeComponents(
4412 MLIRContext *context, ::std::optional<Location> location,
4413 WhileOp::Adaptor adaptor,
4414 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4415 llvm::SmallVector<tosa::YieldOp> yieldOps;
4416 for (auto &block : adaptor.getBodyGraph())
4417 if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
4418 yieldOps.push_back(returnOp);
4419
4420 // TOSA's while must have a tosa.yield as its terminator. If one is not
4421 // found, this tosa.while is invalid.
4422 if (yieldOps.empty())
4423 return failure();
4424
4425 // Get the initial type information from the operand types.
4426 llvm::SmallVector<ValueKnowledge> resultKnowledge;
4427 resultKnowledge.reserve(yieldOps.front().getNumOperands());
4428 for (auto operand : yieldOps.front().getOperands()) {
4429 resultKnowledge.push_back(
4430 ValueKnowledge::getKnowledgeFromType(operand.getType()));
4431 }
4432
4433 for (auto yieldOp : yieldOps) {
4434 if (resultKnowledge.size() != yieldOp.getNumOperands())
4435 return failure();
4436
4437 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
4438 int32_t index = it.index();
4439 if (auto meet = ValueKnowledge::meet(
4440 resultKnowledge[index],
4441 ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
4442 resultKnowledge[index] = meet;
4443 }
4444 }
4445 }
4446
4447 for (const ValueKnowledge &result : resultKnowledge) {
4448 inferredReturnShapes.push_back(result.getShapedTypeComponents());
4449 }
4450
4451 return success();
4452}
4453
4454std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
4455 if (auto vt = llvm::dyn_cast<VectorType>(getType()))
4456 return llvm::to_vector<4>(vt.getShape());
4457 return std::nullopt;
4458}
4459
4460 static void printInitializationList(OpAsmPrinter &parser,
4461 Block::BlockArgListType blocksArgs,
4462 ValueRange initializers,
4463 StringRef prefix = "") {
4464 assert(blocksArgs.size() == initializers.size() &&
4465 "expected same length of arguments and initializers");
4466 if (initializers.empty())
4467 return;
4468
4469 parser << prefix << '(';
4470 llvm::interleaveComma(
4471 llvm::zip(blocksArgs, initializers), parser,
4472 [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
4473 parser << ")";
4474}
4475
4476 // The parse and print methods of IfOp follow the implementation of the SCF dialect.
4477ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
4478 // Create the regions for 'then'.
4479 result.regions.reserve(2);
4480 Region *thenRegion = result.addRegion();
4481 Region *elseRegion = result.addRegion();
4482
4483 OpAsmParser::UnresolvedOperand cond;
4484
4485 if (parser.parseOperand(cond))
4486 return failure();
4487
4488 SmallVector<OpAsmParser::Argument, 4> regionArgs;
4489 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
4490
4491 // Parse the optional block arguments
4492 OptionalParseResult listResult =
4493 parser.parseOptionalAssignmentList(regionArgs, operands);
4494 if (listResult.has_value() && failed(listResult.value()))
4495 return failure();
4496
4497 // Parse a colon.
4498 if (failed(parser.parseColon()))
4499 return parser.emitError(parser.getCurrentLocation(),
4500 "expected type for condition operand");
4501
4502 // Parse the type of the condition operand
4503 Type condType;
4504 if (failed(parser.parseType(condType)))
4505 return parser.emitError(parser.getCurrentLocation(),
4506 "expected type for condition operand");
4507
4508 // Resolve operand with provided type
4509 if (failed(parser.resolveOperand(cond, condType, result.operands)))
4510 return failure();
4511
4512 // Parse optional block arg types
4513 if (listResult.has_value()) {
4514 FunctionType functionType;
4515
4516 if (failed(parser.parseType(functionType)))
4517 return parser.emitError(parser.getCurrentLocation())
4518 << "expected list of types for block arguments "
4519 << "followed by arrow type and list of return types";
4520
4521 result.addTypes(functionType.getResults());
4522
4523 if (functionType.getNumInputs() != operands.size()) {
4524 return parser.emitError(parser.getCurrentLocation())
4525 << "expected as many input types as operands " << "(expected "
4526 << operands.size() << " got " << functionType.getNumInputs()
4527 << ")";
4528 }
4529
4530 // Resolve input operands.
4531 if (failed(parser.resolveOperands(operands, functionType.getInputs(),
4532 parser.getCurrentLocation(),
4533 result.operands)))
4534 return failure();
4535 } else {
4536 // Parse optional results type list.
4537 if (parser.parseOptionalArrowTypeList(result.types))
4538 return failure();
4539 }
4540
4541 // Parse the 'then' region.
4542 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
4543 return failure();
4544
4545 // If we find an 'else' keyword then parse the 'else' region.
4546 if (!parser.parseOptionalKeyword("else")) {
4547 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
4548 return failure();
4549 }
4550
4551 // Parse the optional attribute list.
4552 if (parser.parseOptionalAttrDict(result.attributes))
4553 return failure();
4554 return success();
4555}
4556
4557void IfOp::print(OpAsmPrinter &p) {
4558 p << " " << getCondition();
4559
4560 printInitializationList(p, getThenGraph().front().getArguments(),
4561 getInputList(), " ");
4562 p << " : ";
4563 p << getCondition().getType();
4564
4565 if (!getInputList().empty()) {
4566 p << " (";
4567 llvm::interleaveComma(getInputList().getTypes(), p);
4568 p << ")";
4569 }
4570 p.printArrowTypeList(getResultTypes());
4571 p << " ";
4572
4573 p.printRegion(getThenGraph());
4574
4575 // Print the 'else' region if it exists and has a block.
4576 auto &elseRegion = getElseGraph();
4577 if (!elseRegion.empty()) {
4578 p << " else ";
4579 p.printRegion(elseRegion);
4580 }
4581
4582 p.printOptionalAttrDict((*this)->getAttrs());
4583}
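// Sketch of the custom assembly produced by the printer above, using
// hypothetical SSA names and types rather than output copied from a test:
//
//   tosa.cond_if %cond (%arg0 = %x) : tensor<1xi1> (tensor<4xf32>) -> tensor<4xf32> {
//     ...
//   } else {
//     ...
//   }
//
// The "(%arg0 = %x)" initialization list and the parenthesized input type list
// are omitted when the input list is empty.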
4584
4585LogicalResult IfOp::verify() {
4586 if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
4587 "'then_graph' arguments", getInputList(),
4588 "'input_list'")
4589 .failed())
4590 return failure();
4591
4592 if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
4593 "'else_graph' arguments", getInputList(),
4594 "'input_list'")
4595 .failed())
4596 return failure();
4597
4598 // If the block has no terminator, MLIR's own verifier will report it for us.
4599 if (getThenGraph().front().mightHaveTerminator()) {
4600 auto thenYield =
4601 dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
4602 if (thenYield && errorIfTypeOrShapeMismatch(
4603 *this, thenYield.getInputs(), "'then_graph' results",
4604 getOutputList(), "'output_list'")
4605 .failed())
4606 return failure();
4607 }
4608
4609 // If the block has no terminator, MLIR's own verifier will report it for us.
4610 if (getElseGraph().front().mightHaveTerminator()) {
4611 auto elseYield =
4612 dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
4613 if (elseYield && errorIfTypeOrShapeMismatch(
4614 *this, elseYield.getInputs(), "'else_graph' results",
4615 getOutputList(), "'output_list'")
4616 .failed())
4617 return failure();
4618 }
4619
4620 auto condType = getCondition().getType();
4621 if (errorIfShapeNotSizeOne(*this, condType).failed())
4622 return emitOpError() << "'condition' must be a size 1 tensor, got "
4623 << condType;
4624
4625 return success();
4626}
4627
4628LogicalResult WhileOp::verify() {
4629 if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
4630 getOutputList(), "'output_list'")
4631 .failed())
4632 return failure();
4633
4634 if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
4635 "'cond_graph' arguments", getInputList(),
4636 "'input_list'")
4637 .failed())
4638 return failure();
4639
4640 if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
4641 "'body_graph' arguments", getInputList(),
4642 "'input_list'")
4643 .failed())
4644 return failure();
4645
4646 if (getBodyGraph().front().mightHaveTerminator()) {
4647 auto bodyYield =
4648 dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
4649 if (bodyYield && errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
4650 "'body_graph' results",
4651 getInputList(), "'input_list'")
4652 .failed())
4653 return failure();
4654 }
4655
4656 // Condition block output must be a single element tensor with a single bool
4657 // value.
4658 if (!getCondGraph().front().mightHaveTerminator())
4659 return success();
4660
4661 auto condYield =
4662 dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
4663 if (!condYield)
4664 return success();
4665
4666 if (condYield.getInputs().size() != 1)
4667 return emitOpError() << "require 'cond_graph' only have one result";
4668
4669 auto condOutType = condYield.getInputs()[0].getType();
4670 if (errorIfShapeNotSizeOne(*this, condOutType).failed())
4671 return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
4672 << condOutType;
4673
4674 if (!getElementTypeOrSelf(condOutType).isInteger(1))
4675 return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
4676 << condOutType;
4677
4678 return success();
4679}
4680
4681LogicalResult ReverseOp::verify() {
4682 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
4683 /* outType = */ getOutput().getType())
4684 .failed())
4685 return failure();
4686 TensorType inputType = getInput1().getType();
4687 TensorType outputType = getOutput().getType();
4688 int32_t reverseAxis = getAxis();
4689
4690 if (reverseAxis < 0)
4691 return emitOpError("expected non-negative reverse axis");
4692 if (inputType.hasRank()) {
4693 int64_t inputRank = inputType.getRank();
4694 // We allow for a special case where the input/output shape has rank 0 and
4695 // axis is also 0.
4696 if (reverseAxis >= inputRank && (reverseAxis != 0 || inputRank != 0))
4697 return emitOpError("expect input tensor rank (")
4698 << inputRank << ") to be larger than reverse axis (" << reverseAxis
4699 << ")";
4700 }
4701 if (outputType.hasRank()) {
4702 int64_t outputRank = outputType.getRank();
4703 if (inputType.hasRank() && outputRank != inputType.getRank())
4704 return emitOpError(
4705 "expect output tensor rank to be equal to input tensor rank");
4706 if (reverseAxis >= outputRank && (reverseAxis != 0 || outputRank != 0))
4707 return emitOpError("expect output tensor rank (")
4708 << outputRank << ") to be larger than reverse axis ("
4709 << reverseAxis << ")";
4710 }
4711 return success();
4712}
4713
4714LogicalResult tosa::SelectOp::verify() {
4715 // Verify that input2 and input3 have the same element type as the output.
4716 if (verifySameElementTypes(*this, /* inType = */ getOnTrue().getType(),
4717 /* outType = */ getOutput().getType())
4718 .failed() ||
4719 verifySameElementTypes(*this, /* inType = */ getOnFalse().getType(),
4720 /* outType = */ getOutput().getType())
4721 .failed()) {
4722 return failure();
4723 }
4724 // Verify that input1 has a boolean element type.
4725 auto predicateType = llvm::dyn_cast<ShapedType>(getPred().getType());
4726 if (!predicateType) {
4727 return emitOpError("expect shaped tensor for input1, got ")
4728 << getInput1().getType();
4729 }
4730 auto predicateElementType = predicateType.getElementType();
4731 if (!predicateElementType.isInteger(1)) {
4732 return emitOpError("expect element type of bool for input1, got ")
4733 << predicateElementType;
4734 }
4735
4736 return success();
4737}
4738
4739LogicalResult tosa::VariableReadOp::verify() {
4740 if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
4741 .failed())
4742 return failure();
4743
4744 return success();
4745}
4746
4747LogicalResult tosa::VariableWriteOp::verify() {
4748 if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'")
4749 .failed())
4750 return failure();
4751
4752 return success();
4753}
4754
4755 // The parse and print methods of WhileOp follow the SCF dialect implementation.
4756ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
4757 SmallVector<OpAsmParser::Argument, 4> regionArgs;
4758 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
4759 Region *cond = result.addRegion();
4760 Region *body = result.addRegion();
4761
4762 OptionalParseResult listResult =
4763 parser.parseOptionalAssignmentList(regionArgs, operands);
4764 if (listResult.has_value() && failed(listResult.value()))
4765 return failure();
4766
4767 FunctionType functionType;
4768 SMLoc typeLoc = parser.getCurrentLocation();
4769 if (failed(parser.parseColonType(functionType)))
4770 return failure();
4771
4772 result.addTypes(functionType.getResults());
4773
4774 if (functionType.getNumInputs() != operands.size()) {
4775 return parser.emitError(typeLoc)
4776 << "expected as many input types as operands " << "(expected "
4777 << operands.size() << " got " << functionType.getNumInputs() << ")";
4778 }
4779
4780 // Resolve input operands.
4781 if (failed(parser.resolveOperands(operands, functionType.getInputs(),
4782 parser.getCurrentLocation(),
4783 result.operands)))
4784 return failure();
4785
4786 // Propagate the types into the region arguments.
4787 for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
4788 regionArgs[i].type = functionType.getInput(i);
4789
4790 return failure(parser.parseRegion(*cond, regionArgs) ||
4791 parser.parseKeyword("do") || parser.parseRegion(*body) ||
4792 parser.parseOptionalAttrDictWithKeyword(result.attributes));
4793}
4794
4795void WhileOp::print(OpAsmPrinter &parser) {
4796 printInitializationList(parser, getCondGraph().front().getArguments(),
4797 getInputList(), " ");
4798 parser << " : ";
4799 parser.printFunctionalType(getInputList().getTypes(),
4800 getResults().getTypes());
4801 parser << ' ';
4802 parser.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
4803 parser << " do ";
4804 parser.printRegion(getBodyGraph());
4805 parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
4806}
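// Sketch of the printed form, again with hypothetical names and types:
//
//   tosa.while_loop (%arg0 = %init) : (tensor<1xi32>) -> tensor<1xi32> {
//     ... cond_graph, yielding a tensor<1xi1> ...
//   } do {
//     ... body_graph ...
//   }
//
// The functional type covers the input list and the op results, and the cond
// region is printed without its entry block arguments.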
4807
4808 // Create a rank-1 const tensor for the zero point of the source tensor.
4809std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
4810 Location loc,
4811 Type srcElemType,
4812 int64_t zp) {
4813 srcElemType = getStorageElementTypeOrSelf(srcElemType);
4814 auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
4815 if (llvm::isa<FloatType>(srcElemType)) {
4816 auto zpAttr = DenseElementsAttr::get(
4817 zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
4818 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4819 }
4820 if (llvm::isa<IntegerType>(srcElemType)) {
4821 auto zpAttr =
4822 DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
4823 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4824 }
4825 llvm::errs() << "zero point is not allowed for unsupported data types\n";
4826 return std::nullopt;
4827}
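// Minimal usage sketch (hypothetical caller code, not part of this file):
// materializing a zero-point constant for an i8 input might look like
//
//   Type i8Ty = builder.getIntegerType(8);
//   std::optional<Value> zp =
//       tosa::createZeroPointTensor(builder, loc, i8Ty, /*zp=*/-128);
//   if (!zp)
//     return failure();
//   // *zp is a tosa.const of type tensor<1xi8> holding the value -128.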
4828
4829//===----------------------------------------------------------------------===//
4830// TOSA Shape and Shape Operators Helper functions.
4831//===----------------------------------------------------------------------===//
4832
4833 bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
4834 return mlir::isa<tosa::shapeType>(t);
4835}
4836
4837LogicalResult
4838mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
4839 int rank) {
4840 if (rank < 0)
4841 return emitError() << "invalid rank (must be >= 0): " << rank;
4842 return success();
4843}
4844
4845 LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
4846 for (auto v : op->getOperands()) {
4847 if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
4848 Operation *definingOp = v.getDefiningOp();
4849 if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
4850 return op->emitOpError("shape operand is not compile time resolvable");
4851 }
4852 }
4853 }
4854 return success();
4855}
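// In practice this means every !tosa.shape operand must be produced by an op
// carrying the TosaShapeOperator trait (for example tosa.const_shape); a shape
// value arriving as a block argument has no defining op and is rejected as not
// compile-time resolvable.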
4856
4857LogicalResult
4858 OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
4859 if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)))
4860 return failure();
4861
4862 // Delegate function that returns the rank of a shape type.
4863 auto getRank = [](const Type type) {
4864 return mlir::cast<mlir::tosa::shapeType>(type).getRank();
4865 };
4866 auto operandTypes = op->getOperandTypes();
4867 auto resultTypes = op->getResultTypes();
4868
4869 auto rank = getRank(*op->getOperandTypes().begin());
4870 for (auto type : operandTypes) {
4871 if (getRank(type) != rank) {
4872 return op->emitOpError("operands don't have matching ranks");
4873 }
4874 }
4875 for (auto type : resultTypes) {
4876 if (getRank(type) != rank) {
4877 return op->emitOpError("result shape has different rank than operands");
4878 }
4879 }
4880 return success();
4881}
4882
4883//===----------------------------------------------------------------------===//
4884// TOSA Shape Operators verify functions.
4885//===----------------------------------------------------------------------===//
4886
4887LogicalResult tosa::ConstShapeOp::verify() {
4888 // Check that the values attribute is one-dimensional (rank 1).
4889 auto valuesRank = getValues().getType().getRank();
4890 if (valuesRank != 1)
4891 return emitOpError("expect elements in attribute values with rank 1");
4892 // Check that the number of elements in the values attribute equals the rank of the result shape.
4893 auto count = getValues().getNumElements();
4894 auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
4895 if (count != rank && (count != 1 || rank != 0)) {
4896 return emitOpError("expect number of elements in attribute values (")
4897 << count << ") to be equal to the rank (" << rank
4898 << ") for the result shape type";
4899 }
4900 return success();
4901}
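// Worked example of the check above (illustrative syntax): a values attribute
// such as dense<[2, 3, 4]> : tensor<3xindex> is valid for a result of type
// !tosa.shape<3> (three elements, rank 3), but would be rejected for
// !tosa.shape<2>; the only exception is a single-element values attribute
// paired with a rank-0 !tosa.shape<0> result.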
4902
4903LogicalResult tosa::DimOp::verify() {
4904 const tosa::shapeType outShapeType =
4905 cast<tosa::shapeType>(getResult().getType());
4906 if (outShapeType.getRank() != 1)
4907 return emitOpError("expect output shape type to contain one element, got ")
4908 << outShapeType;
4909
4910 const ShapeAdaptor inputType(getInput1().getType());
4911 if (inputType.hasRank()) {
4912 const int64_t inputRank = inputType.getRank();
4913 const int64_t axis = getAxisAttr().getInt();
4914 if (axis < 0 || axis >= inputRank)
4915 return emitOpError("expect axis to be in the range [0, ")
4916 << inputRank << "), got " << axis;
4917 }
4918 return success();
4919}
4920
4921LogicalResult tosa::ConcatShapeOp::verify() {
4922 const tosa::shapeType outShapeType =
4923 cast<tosa::shapeType>(getResult().getType());
4924 const int64_t outputRank = outShapeType.getRank();
4925 const Operation::operand_range inputList = getInput();
4926
4927 if (inputList.size() == 0)
4928 return emitOpError("requires at least one input shape");
4929
4930 if (llvm::any_of(inputList, [](Value v) {
4931 return cast<tosa::shapeType>(v.getType()).getRank() == 0;
4932 }))
4933 return emitOpError("requires all inputs shapes have a rank greater than 0");
4934
4935 const int64_t inputsRank =
4936 llvm::accumulate(inputList, 0, [](int64_t acc, const Value &input) {
4937 const tosa::shapeType inShapeType =
4938 cast<tosa::shapeType>(input.getType());
4939 return acc + inShapeType.getRank();
4940 });
4941 if (outputRank != inputsRank)
4942 return emitOpError("requires output shape rank to be equal to the sum of "
4943 "the input shape ranks (")
4944 << inputsRank << "), got " << outputRank;
4945
4946 return success();
4947}
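// Worked example of the rank bookkeeping above: concatenating inputs of type
// !tosa.shape<2> and !tosa.shape<3> must produce a !tosa.shape<5> result
// (2 + 3 = 5); a rank-0 input or a mismatched output rank is rejected.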
4948
4949LogicalResult tosa::SliceShapeOp::verify() {
4950 std::optional<int32_t> start;
4951 DenseIntElementsAttr startAttr;
4952 if (matchPattern(getStart(), m_Constant(&startAttr)))
4953 start = startAttr.getValues<int32_t>()[0];
4954 if (start && start.value() < 0)
4955 return emitOpError("expected non-negative start index, got ")
4956 << start.value();
4957
4958 std::optional<int32_t> size;
4959 DenseIntElementsAttr sizeAttr;
4960 if (matchPattern(getSize(), m_Constant(&sizeAttr)))
4961 size = sizeAttr.getValues<int32_t>()[0];
4962 if (size && size.value() <= 0)
4963 return emitOpError("expected positive size, got ") << size.value();
4964
4965 if (!size)
4966 return success();
4967
4968 const tosa::shapeType outShapeType =
4969 cast<tosa::shapeType>(getResult().getType());
4970 const int64_t outputRank = outShapeType.getRank();
4971 if (outputRank != size)
4972 return emitOpError(
4973 "expected output type size to be equal to size attribute, got ")
4974 << outputRank << " vs " << size.value();
4975
4976 if (!start)
4977 return success();
4978
4979 const tosa::shapeType inShapeType =
4980 cast<tosa::shapeType>(getInput().getType());
4981 const int64_t inputRank = inShapeType.getRank();
4982 const int64_t sliceSize = start.value() + size.value();
4983 if (sliceSize > inputRank)
4984 return emitOpError("expected start + size to be less than or equal to "
4985 "input shape rank (")
4986 << inputRank << "), got " << sliceSize;
4987
4988 return success();
4989}
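// Worked example of the constant-operand checks above: slicing a !tosa.shape<4>
// input with start = 1 and size = 2 yields a !tosa.shape<2> result and satisfies
// start + size = 3 <= 4, whereas start = 3 with size = 2 would overrun the input
// rank and be rejected.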
4990
4991//===----------------------------------------------------------------------===//
4992// TOSA Attribute Definitions.
4993//===----------------------------------------------------------------------===//
4994
4995#define GET_ATTRDEF_CLASSES
4996#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
4997
4998//===----------------------------------------------------------------------===//
4999// TOSA Type Definitions.
5000//===----------------------------------------------------------------------===//
5001#define GET_TYPEDEF_CLASSES
5002#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
5003
5004//===----------------------------------------------------------------------===//
5005// TOSA Operator Definitions.
5006//===----------------------------------------------------------------------===//
5007
5008#define GET_OP_CLASSES
5009#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"