// TosaOps.cpp — MLIR TOSA dialect implementation (MLIR/LLVM 23.0.0git).
1//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// This file implements the TOSA Specification:
11// https://www.mlplatform.org/tosa/tosa_spec.html
12//
13//===----------------------------------------------------------------------===//
14
25#include "mlir/IR/Matchers.h"
29#include "llvm/ADT/APFloat.h"
30#include "llvm/ADT/SmallVectorExtras.h"
31#include "llvm/ADT/TypeSwitch.h"
32
33#include <numeric>
34
35using namespace mlir;
36using namespace mlir::tosa;
37
38#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
40
41//===----------------------------------------------------------------------===//
42// Tosa dialect interface includes.
43//===----------------------------------------------------------------------===//
44
45#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
46#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
47#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
48#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
49
50namespace {
51#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
52
53//===----------------------------------------------------------------------===//
54// Dialect Function Inliner Interface.
55//===----------------------------------------------------------------------===//
56struct TosaInlinerInterface : public DialectInlinerInterface {
57 using DialectInlinerInterface::DialectInlinerInterface;
58
59 //===--------------------------------------------------------------------===//
60 // Analysis Hooks.
61 //===--------------------------------------------------------------------===//
62
63 /// All operations can be inlined by default.
64 bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
65 IRMapping &map) const final {
66 return true;
67 }
68
69 /// All regions with If and While parent operators can be inlined.
70 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
71 IRMapping &map) const final {
72 return (isa<tosa::IfOp>(dest->getParentOp()) ||
73 isa<tosa::WhileOp>(dest->getParentOp()));
74 }
75};
76
/// This class implements the bytecode interface for the Tosa dialect.
/// Attribute/type (de)serialization is delegated to the free functions
/// generated into this translation unit from TosaDialectBytecode.cpp.inc.
struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
  TosaDialectBytecodeInterface(Dialect *dialect)
      : BytecodeDialectInterface(dialect) {}

  //===--------------------------------------------------------------------===//
  // Attributes

  /// Reads a dialect attribute via the generated bytecode reader.
  Attribute readAttribute(DialectBytecodeReader &reader) const override {
    return ::readAttribute(getContext(), reader);
  }

  /// Writes a dialect attribute via the generated bytecode writer.
  LogicalResult writeAttribute(Attribute attr,
                               DialectBytecodeWriter &writer) const override {
    return ::writeAttribute(attr, writer);
  }

  //===--------------------------------------------------------------------===//
  // Types

  /// Reads a dialect type via the generated bytecode reader.
  Type readType(DialectBytecodeReader &reader) const override {
    return ::readType(getContext(), reader);
  }

  /// Writes a dialect type via the generated bytecode writer.
  LogicalResult writeType(Type type,
                          DialectBytecodeWriter &writer) const override {
    return ::writeType(type, writer);
  }

  /// Versioning is not implemented yet, so no version payload is emitted.
  void writeVersion(DialectBytecodeWriter &writer) const final {
    // TODO: Populate.
  }

  /// Reading a versioned bytecode file is rejected until versioning lands.
  std::unique_ptr<DialectVersion>
  readVersion(DialectBytecodeReader &reader) const final {
    // TODO: Populate
    reader.emitError("Dialect does not support versioning");
    return nullptr;
  }

  /// No upgrades exist yet; accept any IR as already current.
  LogicalResult upgradeFromVersion(Operation *topLevelOp,
                                   const DialectVersion &version) const final {
    return success();
  }
};
122
123} // namespace
124
125//===----------------------------------------------------------------------===//
126// TOSA control flow support.
127//===----------------------------------------------------------------------===//
128
129/// Returns the while loop body.
130SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
131 return {&getBodyGraph()};
132}
133
134//===----------------------------------------------------------------------===//
135// TOSA variable operator support.
136//===----------------------------------------------------------------------===//
137
139 return map_to_vector(shape, [](int64_t dim) {
140 return dim == -1 ? ShapedType::kDynamic : dim;
141 });
142}
143
144// returns type of variable op
145RankedTensorType mlir::tosa::getVariableType(tosa::VariableOp variableOp) {
146 Type elementType = variableOp.getType();
147 DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
148 auto shape = convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
149 return RankedTensorType::get(shape, elementType);
150}
151
152//===----------------------------------------------------------------------===//
153// Tosa dialect initialization.
154//===----------------------------------------------------------------------===//
155
void TosaDialect::initialize() {
  // Register the generated type, op, and attribute definitions.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
      >();
  // Bytecode support and inliner hooks for TOSA control flow.
  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
  // ShardingInterface implementations are promised here and attached lazily
  // when the sharding extension library is loaded.
  declarePromisedInterfaces<
      shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
      ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
      LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
      LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
      BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
      NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
      GreaterEqualOp, MatMulOp>();
}
179
180Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
181 Type type, Location loc) {
182 // Tosa dialect constants only support ElementsAttr unlike standard dialect
183 // constant which supports all attributes.
184 if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
185 return tosa::ConstShapeOp::create(builder, loc, type,
186 llvm::cast<DenseIntElementsAttr>(value));
187 }
188 if (llvm::isa<ElementsAttr>(value))
189 return tosa::ConstOp::create(builder, loc, type,
190 llvm::cast<ElementsAttr>(value));
191 return nullptr;
192}
193
194//===----------------------------------------------------------------------===//
195// Parsers and printers
196//===----------------------------------------------------------------------===//
197
198namespace {
199
200ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
201 DenseElementsAttr &varShapeAttr,
202 TypeAttr &typeAttr) {
203 if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
204 if (!shapedType.hasRank())
205 return parser.emitError(parser.getCurrentLocation())
206 << "expected ranked type";
207
208 auto elementType = shapedType.getElementType();
209 typeAttr = TypeAttr::get(elementType);
210 ArrayRef<int64_t> shape = shapedType.getShape();
211 Builder builder(parser.getContext());
212 varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
213 return success();
214 }
215 return parser.emitError(parser.getCurrentLocation())
216 << "expected shaped type";
217}
218
219} // namespace
220
221// parses the optional initial value or type for a tosa variable
222// with initial value:
223// tosa.variable @name = dense<0.0> : tensor<1x8xf32>
224//
225// without initial value:
226// tosa.variable @name : tensor<1x8xf32>
228 OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
229 Attribute &initialValueAttr) {
230 if (succeeded(parser.parseOptionalEqual())) {
231 if (failed(parser.parseAttribute(initialValueAttr))) {
232 return parser.emitError(parser.getCurrentLocation())
233 << "expected attribute";
234 }
235 if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
236 return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
237 typeAttr);
238 }
239 return parser.emitError(parser.getCurrentLocation())
240 << "expected Typed attr";
241 }
242
243 initialValueAttr = nullptr;
244 Type parsedType;
245 if (failed(parser.parseColonType(parsedType))) {
246 return parser.emitError(parser.getCurrentLocation())
247 << "expected type after colon";
248 }
249 return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
250}
251
253 OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
254 TypeAttr typeAttr, Attribute initialValueAttr) {
255 bool needsSpace = false;
256 if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
257 auto shape =
258 convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
259 Type elementType = typeAttr.getValue();
260 RankedTensorType tensorType =
261 RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
262 auto tensorTypeAttr = TypeAttr::get(tensorType);
263 p << ": ";
264 p.printAttribute(tensorTypeAttr);
265 needsSpace = true; // subsequent attr value needs a space separator
266 }
267 if (initialValueAttr) {
268 if (needsSpace)
269 p << ' ';
270 p << "= ";
271 p.printAttribute(initialValueAttr);
272 }
273}
274
275namespace {
276
// Parses one `name = value` attribute entry, with special handling for tosa
// enum attributes: the attribute matching EnumType may be written as a *bare*
// enum keyword (e.g. `rounding_mode = DOUBLE_ROUND`) in addition to the
// fully qualified attribute form.
template <typename EnumType>
ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
                                           NamedAttrList &outAttrs) {
  llvm::StringRef name;
  if (parser.parseOptionalKeyword(&name) || parser.parseEqual())
    return failure();

  // special handling: rounding_mode accepts a *bare* RoundingMode enum
  // keyword.
  llvm::StringRef kw;
  if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
    if (name == "rounding_mode" &&
        succeeded(parser.parseOptionalKeyword(&kw))) {
      auto sym = symbolizeRoundingMode(kw);
      if (!sym)
        return parser.emitError(parser.getCurrentLocation())
               << "invalid rounding_mode value: " << kw;
      auto attr = RoundingModeAttr::get(parser.getContext(), sym.value());
      outAttrs.push_back(NamedAttribute(name, attr));
      return success();
    }
  }
  // special handling: mode accepts a *bare* ResizeMode enum keyword.
  if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
    if (name == "mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
      auto sym = symbolizeResizeMode(kw);
      if (!sym)
        return parser.emitError(parser.getCurrentLocation())
               << "invalid resize mode value: " << kw;
      auto attr = ResizeModeAttr::get(parser.getContext(), sym.value());
      outAttrs.push_back(NamedAttribute(name, attr));
      return success();
    }
  }
  // special handling: nan_mode accepts a *bare* NanPropagationMode enum
  // keyword.
  if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
    if (name == "nan_mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
      auto sym = symbolizeNanPropagationMode(kw);
      if (!sym)
        return parser.emitError(parser.getCurrentLocation())
               << "invalid nan_mode value: " << kw;
      auto attr = NanPropagationModeAttr::get(parser.getContext(), sym.value());
      outAttrs.push_back(NamedAttribute(name, attr));
      return success();
    }
  }

  // special handling: block_size accepts a *bare* BlockSize enum keyword.
  if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) {
    if (name == "block_size" && succeeded(parser.parseOptionalKeyword(&kw))) {
      auto sym = symbolizeBlockSize(kw);
      if (!sym)
        return parser.emitError(parser.getCurrentLocation())
               << "invalid block_size value: " << kw;
      auto attr = BlockSizeAttr::get(parser.getContext(), sym.value());
      outAttrs.push_back(NamedAttribute(name, attr));
      return success();
    }
  }

  // Default path: parse any normal attribute literal, including fully
  // qualified enum keyword.
  Attribute attr;
  return parser.parseAttribute(attr, name, outAttrs);
}
344
345template <typename EnumType>
346ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
347 // parse operands
349 if (parser.parseCommaSeparatedList(
350 [&]() { return parser.parseOperand(operands.emplace_back()); }))
351 return failure();
352
353 // Parse { attr-dict } with special handling for enum bare token
354 NamedAttrList attrs;
355 if (succeeded(parser.parseOptionalLBrace()) &&
356 failed(parser.parseOptionalRBrace())) {
357 do {
358 if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))
359 return failure();
360 } while (succeeded(parser.parseOptionalComma()));
361 if (parser.parseRBrace())
362 return failure();
363 }
364
365 FunctionType fnTy;
366 if (parser.parseColonType(fnTy))
367 return failure();
368
369 // Resolve operands and types
370 if (failed(parser.resolveOperands(operands, fnTy.getInputs(),
371 parser.getCurrentLocation(),
372 result.operands)))
373 return failure();
374
375 result.addTypes(fnTy.getResults());
376 result.addAttributes(attrs);
377
378 return success();
379}
380
381void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) {
382 parser << namedAttr.getName().strref() << " = ";
383 auto attr = namedAttr.getValue();
384 if (auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
385 parser << roundingModeAttr.getValue();
386 } else if (auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
387 parser << resizeModeAttr.getValue();
388 } else if (auto nanPropagationModeAttr =
389 dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
390 parser << nanPropagationModeAttr.getValue();
391 } else if (auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) {
392 parser << blockSizeAttr.getValue();
393 } else {
394 parser.printAttribute(attr);
395 }
396}
397
398// print with special handling for default valued NanPropagationMode attribute
399void printWithNanPropagationHandling(OpAsmPrinter &parser, Operation *op) {
400 parser << " ";
401 parser.printOperands(op->getOperands());
402
403 NamedAttrList toPrint(op->getAttrs());
404 // remove default NanPropagate attribute
405 const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;
406 for (auto attr : op->getAttrs()) {
407 if (auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
408 if (nanAttr.getValue() == kDefaultNanValue) {
409 // elide from toPrint
410 toPrint.erase(attr.getName());
411 break;
412 }
413 }
414 }
415
416 if (!toPrint.empty()) {
417 parser << " {";
418 llvm::interleaveComma(toPrint, parser, [&](const NamedAttribute namedAttr) {
419 printNamedAttr(parser, namedAttr);
420 });
421 parser << "}";
422 }
423
424 parser << " : ";
425 parser.printFunctionalType(op);
426}
427
428// print with special handling for enums: RoundingMode, ResizeMode
429void printWithEnumHandling(OpAsmPrinter &parser, Operation *op) {
430 parser << " ";
431 parser.printOperands(op->getOperands());
432
433 if (!op->getAttrs().empty()) {
434 parser << " {";
435 llvm::interleaveComma(op->getAttrs(), parser,
436 [&](const NamedAttribute namedAttr) {
437 printNamedAttr(parser, namedAttr);
438 });
439 parser << "}";
440 }
441
442 parser << " : ";
443 parser.printFunctionalType(op);
444}
445
446} // namespace
447
// tosa.rescale: accept rounding_mode as a bare enum keyword.
ParseResult RescaleOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
}

// tosa.rescale: print enum attributes as bare keywords.
void RescaleOp::print(OpAsmPrinter &parser) {
  printWithEnumHandling(parser, *this);
}
455
// tosa.apply_scale: accept rounding_mode as a bare enum keyword.
ParseResult ApplyScaleOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
}

// tosa.apply_scale: print enum attributes as bare keywords.
void ApplyScaleOp::print(OpAsmPrinter &parser) {
  printWithEnumHandling(parser, *this);
}
463
// tosa.resize: accept mode as a bare ResizeMode enum keyword.
ParseResult ResizeOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::ResizeMode>(parser, result);
}

// tosa.resize: print enum attributes as bare keywords.
void ResizeOp::print(OpAsmPrinter &parser) {
  printWithEnumHandling(parser, *this);
}
471
// tosa.argmax: accept nan_mode as a bare enum keyword.
ParseResult ArgMaxOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

// tosa.argmax: elide the default PROPAGATE nan_mode when printing.
void ArgMaxOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}
479
// tosa.max_pool2d: accept nan_mode as a bare enum keyword.
ParseResult MaxPool2dOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

// tosa.max_pool2d: elide the default PROPAGATE nan_mode when printing.
void MaxPool2dOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}
487
// tosa.clamp: accept nan_mode as a bare enum keyword.
ParseResult ClampOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

// tosa.clamp: elide the default PROPAGATE nan_mode when printing.
void ClampOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}
495
// tosa.maximum: accept nan_mode as a bare enum keyword.
ParseResult MaximumOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

// tosa.maximum: elide the default PROPAGATE nan_mode when printing.
void MaximumOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}
503
// tosa.minimum: accept nan_mode as a bare enum keyword.
ParseResult MinimumOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

// tosa.minimum: elide the default PROPAGATE nan_mode when printing.
void MinimumOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}
511
// tosa.reduce_max: accept nan_mode as a bare enum keyword.
ParseResult ReduceMaxOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

// tosa.reduce_max: elide the default PROPAGATE nan_mode when printing.
void ReduceMaxOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}
519
// tosa.reduce_min: accept nan_mode as a bare enum keyword.
ParseResult ReduceMinOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

// tosa.reduce_min: elide the default PROPAGATE nan_mode when printing.
void ReduceMinOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}
527
528ParseResult MatmulTBlockScaledOp::parse(OpAsmParser &parser,
530 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
531}
532
533void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) {
534 printWithEnumHandling(parser, *this);
535}
536
537ParseResult CastFromBlockScaledOp::parse(OpAsmParser &parser,
539 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
540}
541
542void CastFromBlockScaledOp::print(OpAsmPrinter &parser) {
543 printWithEnumHandling(parser, *this);
544}
545
546ParseResult CastToBlockScaledOp::parse(OpAsmParser &parser,
548 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
549}
550
551void CastToBlockScaledOp::print(OpAsmPrinter &parser) {
552 printWithEnumHandling(parser, *this);
553}
554
555ParseResult Conv2DBlockScaledOp::parse(OpAsmParser &parser,
557 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
558}
559
560void Conv2DBlockScaledOp::print(OpAsmPrinter &parser) {
561 printWithEnumHandling(parser, *this);
562}
563
564//===----------------------------------------------------------------------===//
565// Tosa utilities.
566//===----------------------------------------------------------------------===//
567
// Returns lhs / rhs when rhs divides lhs exactly, std::nullopt otherwise.
// A zero divisor is reported as "not divisible" instead of triggering
// undefined behavior from integer division by zero.
static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
  if (rhs == 0 || lhs % rhs != 0)
    return std::nullopt;
  return lhs / rhs;
}
573
575 auto srcType = getElementTypeOrSelf(type);
576 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
577 srcType = getStorageElementTypeFromQuantized(quantType);
578 return srcType;
579}
580
584
585static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
586 Value valZp, StringRef name) {
588 Type eZpType = getStorageElementTypeOrSelf(valZp.getType());
589
590 bool bothInts =
591 mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
592 bool sameBitWidth =
593 (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());
594
595 if (!bothInts || !sameBitWidth) {
596 return op->emitOpError()
597 << "expected " << name << " and " << name
598 << "_zp to both be integer of the same bitwidth, but got " << eType
599 << " vs. " << eZpType;
600 }
601 return success();
602}
603
604// Create a pad-const const tensor with value of `val` of required data-type
606 Value src, int32_t val) {
607 const auto srcType = getElementTypeOrSelf(src);
608 const auto srcElemType = getStorageElementTypeOrSelf(src);
609 const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
610 const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
611 const auto padConstAttr{
612 llvm::isa<FloatType>(srcElemType)
613 ? DenseElementsAttr::get(padConstEType,
614 builder.getFloatAttr(srcElemType, val))
615 : DenseElementsAttr::get(padConstEType,
616 builder.getIntegerAttr(srcElemType, val))};
617 return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
618}
619
  // NOTE(review): the opening signature of this helper (original line 620) is
  // not visible in this view; from the body it maps a Type to a bit width.
  // mxint8 has no int/float bit width of its own, so it is special-cased to 8.
  if (dyn_cast<tosa::mxint8Type>(type))
    return 8;
  return type.getIntOrFloatBitWidth();
}
625
626// Update dim size if current dim is dynamic, otherwise raise an error if sizes
627// do not match
628LogicalResult tryUpdateDimOrFailure(Operation *op, int64_t &currDim,
629 const int64_t newDim,
630 const StringRef operandName,
631 const StringRef dimName) {
632 if (ShapedType::isDynamic(currDim)) {
633 currDim = newDim;
634 return success();
635 } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
636 return op->emitOpError("expected ")
637 << dimName << " of " << operandName << " to match size " << currDim
638 << ", got " << newDim;
639 }
640 return success();
641}
642
644 Operation *op, const int64_t inputSize, const int64_t kernelSize,
645 const int64_t outputSize, const int64_t padBefore, const int64_t padAfter,
646 const int64_t stride, const int64_t dilation, const llvm::StringRef dimName,
647 const llvm::StringRef dimAxis, const llvm::StringRef padBeforeName,
648 const llvm::StringRef padAfterName) {
649 if (inputSize == ShapedType::kDynamic || kernelSize == ShapedType::kDynamic)
650 return success();
651
652 // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
653
654 const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
655 inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
656 stride);
657 if (!calculatedOutSizeMinusOne.has_value())
658 return op->emitOpError("expected input_")
659 << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
660 << padAfterName << " - (kernel_" << dimName << " - 1) * dilation_"
661 << dimAxis << " to be wholly divisible by stride_" << dimAxis
662 << ", got (" << inputSize << " - 1 + " << padBefore << " + "
663 << padAfter << " - (" << kernelSize << " - 1) * " << dilation
664 << ") / " << stride;
665
666 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
667 if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
668 return op->emitOpError("calculated output ")
669 << dimName << " did not match expected: "
670 << "calculated=" << calculatedOutSize << ", expected=" << outputSize;
671
672 return success();
673}
674
675//===----------------------------------------------------------------------===//
676// TOSA Operator Verifiers.
677//===----------------------------------------------------------------------===//
678
// Shared verifier for convolution-like ops: checks element-type compatibility
// of input, weight, bias, result, and the operand zero points.
template <typename T>
static LogicalResult verifyConvOp(T op) {
  // NOTE(review): these dyn_cast results are dereferenced below without a
  // null check; presumably ODS constrains input/weight to tensors — confirm.
  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();
  auto biasEType =
      llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
  auto resultEType =
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
  // Float-ness is recorded before quantized types are unwrapped below.
  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
  bool resultIsFloat = llvm::isa<FloatType>(resultEType);

  // All subsequent comparisons are on storage types for quantized elements.
  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = getStorageElementTypeFromQuantized(quantType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
    weightEType = getStorageElementTypeFromQuantized(quantType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
    biasEType = getStorageElementTypeFromQuantized(quantType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = getStorageElementTypeFromQuantized(quantType);

  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
    // for now, only enforce bias element type == result element type for
    // float types.
    op.emitOpError(
        "expect both bias and result to have same element type, got ")
        << biasEType << " and " << resultEType;
    return failure();
  }

  // f8 inputs and weights must match each other exactly.
  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
      isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
    if (inputEType != weightEType) {
      op.emitOpError(
          "expect both input and weight to have same element type, got ")
          << inputEType << " and " << weightEType;
      return failure();
    }
  }

  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
  bool weightIsFloat = llvm::isa<FloatType>(weightEType);

  // Either both must be float or both non-float.
  if (inputIsFloat != weightIsFloat) {
    op.emitOpError(
        "expect both input and weight to be float or not together, got ")
        << inputEType << " and " << weightEType;
    return failure();
  }

  // Zero points must share the (storage) element type of their operand.
  auto inputZpEType = getStorageElementTypeOrSelf(op.getInputZp().getType());
  if (inputEType != inputZpEType) {
    return op.emitOpError("expect both input and its zero point are the same "
                          "element type, got ")
           << inputEType << " and " << inputZpEType;
  }

  auto weightZpEType = getStorageElementTypeOrSelf(op.getWeightZp().getType());
  if (weightEType != weightZpEType) {
    return op.emitOpError("expect both weight and its zero point are the same "
                          "element type, got ")
           << weightEType << " and " << weightZpEType;
  }

  // Validate the zero-point values when they fold to constants.
  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
    return failure();

  return success();
}
759
760LogicalResult tosa::ConstOp::verify() {
761
762 auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
763 auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());
764
765 if (!attrType || !outputType) {
766 emitOpError("expected tensors for attr/result type");
767 return failure();
768 }
769
770 if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
771 outputType.getElementType())) {
772 if (getStorageElementTypeFromQuantized(result) == attrType.getElementType())
773 return success();
774 }
775
776 if (attrType.getElementType() != outputType.getElementType()) {
777 emitOpError("expected same attr/result element types");
778 return failure();
779 }
780
781 return success();
782}
783
784template <typename T>
785static LogicalResult verifyConvOpModes(T op) {
786 auto inputEType =
787 llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
788
789 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
790 inputEType = getStorageElementTypeFromQuantized(quantType);
791
792 auto accType = op.getAccType();
793 if (inputEType.isInteger(8) && !accType.isInteger(32))
794 return op.emitOpError("accumulator type for i8 tensor is not i32, got ")
795 << accType;
796
797 if (inputEType.isInteger(16) && !accType.isInteger(48))
798 return op.emitOpError("accumulator type for i16 tensor is not i48, got ")
799 << accType;
800
801 if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) &&
802 !(accType.isF16() || accType.isF32()))
803 return op.emitOpError("accumulator type for f8 tensor is not f16/f32, got ")
804 << accType;
805
806 if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
807 return op.emitOpError(
808 "accumulator type for f16 tensor is not f16/f32, got ")
809 << accType;
810
811 if (inputEType.isBF16() && !accType.isF32())
812 return op.emitOpError("accumulator type for bf16 tensor is not f32, got ")
813 << accType;
814
815 if (inputEType.isF32() && !accType.isF32())
816 return op.emitOpError("accumulator type for f32 tensor is not f32, got ")
817 << accType;
818
819 auto resultEType =
820 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
821
822 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
823 resultEType = getStorageElementTypeFromQuantized(quantType);
824
825 return success();
826}
827
828//===----------------------------------------------------------------------===//
829// ERROR_IF functions.
830// ERROR_IF is a predicate that must set an error if the condition holds.
831//===----------------------------------------------------------------------===//
832
// ERROR_IF checks shared by the convolution ops: pad/stride/dilation value
// ranges, per-dimension output-size consistency, and bias-channel count.
// NOTE(review): padding/strides/dilations are indexed without a size check;
// presumably ODS fixes the array lengths per op — confirm.
template <typename T>
static LogicalResult verifyConvOpErrorIf(T op) {
  llvm::ArrayRef<int64_t> padding = op.getPad();
  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")
           << padding;

  llvm::ArrayRef<int64_t> strides = op.getStride();
  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")
           << strides;

  llvm::ArrayRef<int64_t> dilations = op.getDilation();
  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
    return op.emitOpError("expect all dilation values to be >= 1, got ")
           << dilations;

  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
  if (!outputType)
    // Skip following checks if output is not ranked
    return success();

  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const RankedTensorType weightType =
      llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());

  // Per-op spatial layout determines which dims pair with which pad/stride/
  // dilation entries; dispatched statically on the op type.
  if (inputType && weightType) {
    // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
    if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))
        return failure();

      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))
        return failure();
    }

    // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
    if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(1), weightType.getDimSize(0),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))
        return failure();

      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(2), weightType.getDimSize(1),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))
        return failure();
    }

    // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
    if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "depth", "d", "front", "back")))
        return failure();

      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "height", "y", "top", "bottom")))
        return failure();

      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(3), weightType.getDimSize(3),
              outputType.getDimSize(3), padding[4], padding[5], strides[2],
              dilations[2], "width", "x", "left", "right")))
        return failure();
    }
  }

  const RankedTensorType biasType =
      llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
  if (!biasType)
    // Skip following checks if bias is not ranked
    return success();

  const int64_t biasChannels = biasType.getDimSize(0);
  const int64_t outputChannels =
      outputType.getDimSize(outputType.getRank() - 1);
  if (biasChannels == ShapedType::kDynamic ||
      outputChannels == ShapedType::kDynamic)
    // Skip following checks if biasChannels or outputChannels is dynamic dim
    return success();

  // Bias may either match the output channel count or broadcast from 1.
  if (biasChannels != outputChannels && biasChannels != 1)
    return op.emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;

  return success();
}
935
936// Verify whether same type and shape of the given two types.
937static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
938 StringRef name1, Type type2,
939 StringRef name2) {
940 auto shapeType1 = dyn_cast<ShapedType>(type1);
941 auto shapeType2 = dyn_cast<ShapedType>(type2);
942 if (!shapeType1 || !shapeType2)
943 return failure();
944
945 auto elemType1 = shapeType1.getElementType();
946 auto elemType2 = shapeType2.getElementType();
947 if (elemType1 != elemType2)
948 return op->emitOpError()
949 << "require same element type for " << name1 << " (" << elemType1
950 << ") and " << name2 << " (" << elemType2 << ")";
951
952 if (failed(verifyCompatibleShape(type1, type2)))
953 return op->emitOpError()
954 << "require same shapes for " << name1 << " (" << type1 << ") and "
955 << name2 << " (" << type2 << ")";
956
957 return success();
958}
959
960// Verify whether same length, type, and shape of the given two tensor lists.
961static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, ValueRange list1,
962 StringRef name1,
963 ValueRange list2,
964 StringRef name2) {
965 if (list1.size() != list2.size())
966 return op->emitOpError()
967 << "require same number of values in " << name1 << " ("
968 << list1.size() << ") and " << name2 << " (" << list2.size() << ")";
969
970 for (auto [type1, type2] :
971 llvm::zip_equal(list1.getTypes(), list2.getTypes())) {
972 if (errorIfTypeOrShapeMismatch(op, type1, name1, type2, name2).failed())
973 return failure();
974 }
975
976 return success();
977}
978
979static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
980 ShapeAdaptor shapeAdaptor(type);
981 if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
982 return success();
983
984 return shapeAdaptor.getNumElements() == 1 ? success() : failure();
985}
986
// Verifies a variable access op `op` against the 'tosa.variable' it names:
// the declaration must exist in the nearest enclosing symbol table, and the
// declared variable type must match `type`. `name` is used in diagnostics.
template <typename T>
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
  Operation *symTableOp =
      op->template getParentWithTrait<OpTrait::SymbolTable>();
  if (!symTableOp)
    // If the operation is not the scope of a symbol table, we cannot
    // verify it against its declaration.
    return success();

  SymbolTable symTable(symTableOp);
  const auto varOp = symTable.lookup<tosa::VariableOp>(op.getName());

  // Verify prior declaration
  if (!varOp)
    return op->emitOpError("'")
           << op.getName() << "' has not been declared by 'tosa.variable'";

  // Verify type and shape against the declared variable type; the helper
  // emits the detailed mismatch diagnostic.
  auto variableType = getVariableType(varOp);
  if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
                                 "the input tensor")
          .failed())
    return failure();
  return success();
}
1012
1013// verify that inType and outType have same element types
1014template <typename T>
1015static LogicalResult verifySameElementTypes(T op, Type aType, Type bType,
1016 StringRef aName = "input",
1017 StringRef bName = "output") {
1018 auto aTType = llvm::dyn_cast<TensorType>(aType);
1019 auto bTType = llvm::dyn_cast<TensorType>(bType);
1020 if (!aTType) {
1021 op.emitOpError("expect shaped tensor for") << aName << ", got " << aType;
1022 return failure();
1023 }
1024 if (!bTType) {
1025 op.emitOpError("expect shaped tensor for") << bName << ", got" << bType;
1026 return failure();
1027 }
1028 auto aElementType = aTType.getElementType();
1029 auto bElementType = bTType.getElementType();
1030 auto aQuantType =
1031 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType);
1032 auto bQuantType =
1033 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType);
1034 if ((aElementType.isIntOrIndexOrFloat() || aQuantType) &&
1035 (bElementType.isIntOrIndexOrFloat() || bQuantType) &&
1036 aElementType != bElementType) {
1037 // only check if both element types are int/index/float/UniformQuantized
1038 // eg, not sure how to check quant::QuantizedType
1039 // this happens in test_conv2d_q_grouped_convolution in
1040 // tfl-to-tosa-pipeline.mlir
1041 op.emitOpError("expect ")
1042 << aName << " and " << bName << " to have same element type, got "
1043 << aElementType << " and " << bElementType;
1044 return failure();
1045 }
1046 return success();
1047}
1048
// Verifies tosa.argmax: the result element type must be integer/index, the
// axis must lie within the input rank, and (when both shapes are ranked) the
// output shape must equal the input shape with the reduction axis removed.
LogicalResult tosa::ArgMaxOp::verify() {
  const ShapedType resultType = llvm::cast<ShapedType>(getType());

  // Ensure output is of integer (or index) element type
  if (const auto resultETy = resultType.getElementType();
      !resultETy.isIntOrIndex())
    return emitOpError("result tensor is not of integer type");

  // Without a ranked input, neither the axis nor the shape can be checked.
  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
  if (!inputType.hasRank())
    return success();

  // Ensure axis is within the tensor rank
  const int64_t axis = getAxisAttr().getInt();
  if (((axis < 0) || axis >= inputType.getRank()))
    return emitOpError("specified axis is outside the rank of the tensor");

  if (!resultType.hasRank())
    return success();

  // Expected output shape is the input shape with `axis` erased; dynamic
  // dimensions are treated as compatible by verifyCompatibleShape.
  const ArrayRef<int64_t> inputShape = inputType.getShape();
  const ArrayRef<int64_t> outputShape = resultType.getShape();
  llvm::SmallVector<int64_t> expectedOutputShape(inputShape);
  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
    return emitOpError("expected output shape '")
           << expectedOutputShape << "', got '" << outputShape << "'";

  return success();
}
1079
// Shared verifier for 2D pooling ops. Checks the attribute ranges
// (kernel/stride >= 1, pad >= 0), that padding stays strictly below the
// kernel extent, and — when both input and output are ranked — that any
// static output H/W dimension matches the computed pooled size.
// Attribute layout (as used below): kernel = [y, x], stride = [y, x],
// pad = [top, bottom, left, right].
template <typename T>
static LogicalResult verifyPoolingOp(T op) {
  const llvm::ArrayRef<int64_t> kernel = op.getKernel();
  if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all kernel values to be >= 1, got ")
           << kernel;

  const llvm::ArrayRef<int64_t> strides = op.getStride();
  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")
           << strides;

  const llvm::ArrayRef<int64_t> padding = op.getPad();
  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")
        << padding;

  // Padding must be less than kernel size to avoid a divide-by-zero
  const int64_t kernelX = kernel[1];
  const int64_t padLeft = padding[2];
  const int64_t padRight = padding[3];
  if (padRight >= kernelX || padLeft >= kernelX)
    return op.emitOpError("expected left/right padding to be less than the "
                          "width of the kernel, got pad_left=")
           << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;

  const int64_t kernelY = kernel[0];
  const int64_t padTop = padding[0];
  const int64_t padBottom = padding[1];
  if (padTop >= kernelY || padBottom >= kernelY)
    return op.emitOpError("expected top/bottom padding to be less than the "
                          "height of the kernel, got pad_top=")
           << padTop << ", pad_bottom=" << padBottom
           << ", kernel_y=" << kernelY;

  // The size checks below require ranked input and output types.
  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
  if (!inputType || !outputType)
    return success();

  // Checks one spatial dimension: the padded input extent minus the kernel
  // must divide evenly by the stride, and any static output size must equal
  // (inputSize + padBefore + padAfter - kernelSize) / strideSize + 1.
  const auto verifyOutputSize =
      [&op](const int64_t inputSize, const int64_t outputSize,
            const int64_t kernelSize, const int64_t strideSize,
            const int64_t padBefore, const int64_t padAfter,
            const llvm::StringRef dimName, const llvm::StringRef dimAxis,
            const llvm::StringRef padBeforeName,
            const llvm::StringRef padAfterName) -> LogicalResult {
    if (ShapedType::isDynamic(inputSize))
      return success();

    // idivCheck returns nullopt when the division is not exact.
    const std::optional<int64_t> calculatedOutSizeMinusOne =
        idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
    if (!calculatedOutSizeMinusOne.has_value())
      return op.emitOpError("expected input_")
             << dimName << " + pad_" << padBeforeName << " + pad_"
             << padAfterName << " - kernel_" << dimAxis
             << " to be wholly divisible by stride_" << dimAxis << ", got ("
             << inputSize << " + " << padBefore << " + " << padAfter << " - "
             << kernelSize << ") / " << strideSize;

    const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
    if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
      return op.emitOpError("calculated output ")
             << dimName << " did not match expected: " << "calculated="
             << calculatedOutSize << ", expected=" << outputSize;

    return success();
  };

  // NHWC layout: dim 1 is height, dim 2 is width.
  if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
                              kernel[0], strides[0], padding[0], padding[1],
                              "height", "y", "top", "bottom")))
    return failure();

  if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
                              kernel[1], strides[1], padding[2], padding[3],
                              "width", "x", "left", "right")))
    return failure();

  return success();
}
1163
// Verifies tosa.avg_pool2d: shared pooling checks, then the accumulator
// type constraints per input element type, matching element types for the
// input/output zero points, and the zero-point value constraints.
LogicalResult tosa::AvgPool2dOp::verify() {
  if (failed(verifyPoolingOp(*this)))
    return failure();

  // Compare storage element types so quantized tensors are checked against
  // their underlying storage type.
  const Type inputETy = getStorageElementTypeOrSelf(getInput().getType());
  const Type resultETy = getStorageElementTypeOrSelf(getOutput().getType());
  const Type inputZpETy = getStorageElementTypeOrSelf(getInputZp().getType());
  const Type outputZpETy = getStorageElementTypeOrSelf(getOutputZp().getType());

  // Accumulator type is constrained by the input element type.
  auto accType = getAccType();
  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
    return emitOpError("accumulator type for integer tensor is not i32");

  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
    return emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputETy.isBF16() && !accType.isF32())
    return emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputETy.isF32() && !accType.isF32())
    return emitOpError("accumulator type for f32 tensor is not f32");

  if (inputETy != inputZpETy)
    return emitOpError("expect both input and its zero point are the same "
                       "element type, got ")
           << inputETy << " and " << inputZpETy;

  if (resultETy != outputZpETy)
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << resultETy << " and " << outputZpETy;

  // Zero-point value checks only run when the zero point is resolvable
  // (e.g. a constant); otherwise they are skipped.
  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  return success();
}
1206
// Verifies tosa.clamp: input/output element types must match (quantized
// types are compared via their storage type), the min_val/max_val attributes
// must be of the input's type, and min_val <= max_val must hold (with NaN
// rejected for floats).
LogicalResult tosa::ClampOp::verify() {
  mlir::Type inputETy =
      llvm::cast<ShapedType>(getInput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
    inputETy = getStorageElementTypeFromQuantized(quantType);
  }
  mlir::Type outputETy =
      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
    outputETy = getStorageElementTypeFromQuantized(quantType);
  }
  if (inputETy != outputETy)
    return emitOpError("input/output element types are incompatible.");

  auto maxValAttr = getMaxValAttr();
  auto minValAttr = getMinValAttr();

  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();

  // isInteger(width) with the type's own width holds for any integer type,
  // regardless of signedness — so this branch covers all integer inputs.
  if (inputETy.isInteger(dataTypeBitWidth)) {
    // if input datatype is integer, check that the min_val/max_val attributes
    // are integer attributes, and that their type is the same as the input's
    // datatype
    auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
    auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
    if (!intMaxValAttr || !intMinValAttr ||
        (intMaxValAttr.getType() != intMinValAttr.getType()) ||
        (intMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    // Unsigned and boolean values compare with unsigned semantics; all other
    // integers compare as signed.
    const bool isUnsigned = inputETy.isUnsignedInteger();
    const bool isBoolean = inputETy.isInteger(1);
    const APInt minVal = intMinValAttr.getValue();
    const APInt maxVal = intMaxValAttr.getValue();
    if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;
  } else {
    // otherwise, input datatype is float, check that the min_val/max_val
    // attributes share the same type and that their type is the same as the
    // input's datatype
    auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
    auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
    if (!floatMaxValAttr || !floatMinValAttr ||
        (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
        (floatMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const APFloat minVal = floatMinValAttr.getValue();
    const APFloat maxVal = floatMaxValAttr.getValue();
    if (minVal.isNaN() || maxVal.isNaN())
      return emitOpError("min/max attributes should not be 'NaN', got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;

    if (maxVal < minVal)
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;
  }

  return success();
}
1272
1273//===----------------------------------------------------------------------===//
1274// TOSA Operator Quantization Builders.
1275//===----------------------------------------------------------------------===//
1276
1277/// This builder is called on all convolution operators except TransposeConv,
1278/// which has specialized output shape semantics. The builder also defines the
1279/// bitwidth of the output given the bit width of the input & weight content.
1281 Type outputType, Value input, Value weight,
1282 Value bias, DenseI64ArrayAttr pad,
1283 DenseI64ArrayAttr stride,
1284 DenseI64ArrayAttr dilation,
1285 TypeAttr accType) {
1286 auto zps = createZPsAsConst(builder, input, weight);
1287 result.addOperands({input, weight, bias, zps.first, zps.second});
1288 result.addAttribute("pad", pad);
1289 result.addAttribute("stride", stride);
1290 result.addAttribute("dilation", dilation);
1291 result.addAttribute("acc_type", accType);
1292 Type finalOutputType = outputType;
1293 auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1294 if (quantAttr) {
1295 finalOutputType =
1296 buildConvOpResultTypeInfo(builder, outputType, input, weight);
1297 }
1298 result.addTypes(finalOutputType);
1299}
1300
1301/// Handles tosa.transpose_conv2d which has outpad and output shape
1302/// attributes.
1303static void
1305 Type outputType, Value input, Value weight,
1306 Value bias, DenseI64ArrayAttr outpad,
1307 DenseI64ArrayAttr stride, TypeAttr accType) {
1308 auto zps = createZPsAsConst(builder, input, weight);
1309 result.addOperands({input, weight, bias, zps.first, zps.second});
1310 result.addAttribute("out_pad", outpad);
1311 result.addAttribute("stride", stride);
1312 result.addAttribute("acc_type", accType);
1313 Type finalOutputType = outputType;
1314 auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1315 if (quantAttr) {
1316 finalOutputType =
1317 buildConvOpResultTypeInfo(builder, outputType, input, weight);
1318 }
1319 result.addTypes(finalOutputType);
1320}
1321
1322/// The tosa.matmul op is also intended to be generated where a fully_connected
1323/// op must be constructed where the weight is not a constant. In this case,
1324/// the fully_connected op must be expressed using matmul.
1325/// TODO: Add link to the leglization document explaining this.
1327 OperationState &result, Type outputType,
1328 Value a, Value b) {
1329 auto zps = createZPsAsConst(builder, a, b);
1330 result.addOperands({a, b, zps.first, zps.second});
1331
1332 Type finalOutputType{outputType};
1333 if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
1334 auto eType = getStorageElementTypeOrSelf(a.getType());
1335 auto inputBits = eType.getIntOrFloatBitWidth();
1336
1337 auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
1338 assert(outputShapedType && "Output must be a shaped type");
1339
1340 IntegerType accElementType;
1341 if (inputBits == 16)
1342 accElementType = builder.getIntegerType(48);
1343 else
1344 accElementType = builder.getI32Type();
1345
1346 finalOutputType = outputShapedType.clone(accElementType);
1347 }
1348 result.addTypes(finalOutputType);
1349}
1350
1351/// Both the tosa.avg_pool2d and unary ops use the same
1352/// UnaryOpQuantizationAttr but avg_pool operator has its own builder as it
1353/// has additional parameters not part of the unary ops.
1354static void
1356 Type outputType, Value input,
1357 DenseArrayAttr kernel, DenseArrayAttr stride,
1358 DenseArrayAttr pad, TypeAttr accType) {
1359 const Location loc{result.location};
1360 int64_t inputZp{0};
1361 int64_t outputZp{0};
1362
1363 if (auto quantAttr =
1364 buildUnaryOpQuantizationAttr(builder, input, outputType)) {
1365 inputZp = quantAttr.getInputZp();
1366 outputZp = quantAttr.getOutputZp();
1367 }
1368 const std::optional<Value> inputZpOp =
1369 createZeroPointTensor(builder, loc, input.getType(), inputZp);
1370 if (!inputZpOp) {
1371 (void)emitError(
1372 loc,
1373 "Failed to create input zero point tensor for quantized AVG_POOL2D op");
1374 }
1375 const std::optional<Value> outputZpOp =
1376 createZeroPointTensor(builder, loc, outputType, outputZp);
1377 if (!outputZpOp) {
1378 (void)emitError(loc, "Failed to create output zero point tensor for "
1379 "quantized AVG_POOL2D op");
1380 }
1381
1382 if (inputZpOp && outputZpOp) {
1383 result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
1384 } else {
1385 // failed to create one or more zero points above: just add input as
1386 // operands this will trigger error in building the op because of missing
1387 // zero points
1388 result.addOperands({input});
1389 }
1390 result.addAttribute("kernel", kernel);
1391 result.addAttribute("stride", stride);
1392 result.addAttribute("pad", pad);
1393 result.addAttribute("acc_type", accType);
1394 result.types.push_back(outputType);
1395}
1396
1397/// This builder is called on single-parameter negate operator
1398/// to construct input and output zero points based on their
1399/// types.
1401 OperationState &result, Type outputType,
1402 Value input) {
1403 const Location loc{result.location};
1404 int64_t input1Zp{0};
1405 int64_t outputZp{0};
1406 auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
1407 if (quantAttr) {
1408 input1Zp = quantAttr.getInputZp();
1409 outputZp = quantAttr.getOutputZp();
1410 }
1411 const std::optional<Value> input1ZpOp =
1412 createZeroPointTensor(builder, loc, input.getType(), input1Zp);
1413 if (!input1ZpOp) {
1414 (void)emitError(
1415 loc, "Failed to create input1 zero point for quantized NEGATE op");
1416 }
1417
1418 const std::optional<Value> outputZpOp =
1419 createZeroPointTensor(builder, loc, input.getType(), outputZp);
1420 if (!outputZpOp) {
1421 (void)emitError(
1422 loc, "Failed to create output zero point for quantized NEGATE op");
1423 }
1424
1425 if (input1ZpOp && outputZpOp) {
1426 result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
1427 } else {
1428 // failed to create one or more zero points above: just add input as
1429 // operands. This will trigger error in building the op because of
1430 // missing zero points
1431 result.addOperands({input});
1432 }
1433
1434 result.types.push_back(outputType);
1435}
1436
1437/// This builder is called on TOSA pad operator that needs to create its own
1438/// OptionalAttr quantization_attr parameter to scale the padding values
1439/// correctly. No pad_const is interpreted as zero-padding.
1441 Type outputType, Value input,
1442 Value paddings) {
1443 const Location loc{result.location};
1444 int32_t zp{0};
1445 const auto quantAttr = buildPadOpQuantizationAttr(builder, input);
1446 if (quantAttr) {
1447 zp = static_cast<int32_t>(quantAttr.getInputZp());
1448 }
1449 const auto padConstOp{createPadConstTensor(builder, loc, input, zp)};
1450 result.addOperands({input, paddings, padConstOp});
1451 result.types.push_back(outputType);
1452}
1453
1455 StringRef name, Type variableType,
1456 Attribute initialValue) {
1457 const Location loc{result.location};
1458 auto nameAttr = builder.getStringAttr(name);
1459
1460 auto shapedType = dyn_cast<ShapedType>(variableType);
1461 if (!shapedType) {
1462 (void)emitError(loc, "variable type must be a shaped type");
1463 return;
1464 }
1465 if (!shapedType.hasRank()) {
1466 (void)emitError(loc, "variable type must be a ranked type");
1467 return;
1468 }
1469
1470 auto elementType = shapedType.getElementType();
1471 auto elementTypeAttr = TypeAttr::get(elementType);
1472 ArrayRef<int64_t> shape = shapedType.getShape();
1473 auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
1474
1475 result.addAttribute("sym_name", nameAttr);
1476 result.addAttribute("var_shape", varShapeAttr);
1477 result.addAttribute("type", elementTypeAttr);
1478 result.addAttribute("initial_value", initialValue);
1479}
1480
1481//===----------------------------------------------------------------------===//
1482// TOSA Operator Return Type Inference.
1483//===----------------------------------------------------------------------===//
1484static FailureOr<int64_t> resolveBroadcastDim(const int64_t dim1,
1485 const int64_t dim2) {
1486 if (dim1 == 1)
1487 return dim2;
1488 if (dim2 == 1)
1489 return dim1;
1490
1491 if (ShapedType::isStatic(dim1) && ShapedType::isStatic(dim2) && dim1 != dim2)
1492 return failure();
1493
1494 // Prefer static dimension over dynamic
1495 return ShapedType::isDynamic(dim1) ? dim2 : dim1;
1496}
1497
1498static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
1499 SmallVector<int64_t> &outShape) {
1500 int64_t outRank = 0;
1501 for (int i = 0, e = operands.size(); i != e; ++i) {
1502 auto shape = operands.getShape(i);
1503 if (!shape.hasRank()) {
1504 // TODO(jennik): Update function to have better case handling for
1505 // invalid operands and for ranked tensors.
1506 return failure();
1507 }
1508 outRank = std::max<int64_t>(outRank, shape.getRank());
1509 }
1510
1511 outShape.resize(outRank, 1);
1512
1513 for (int i = 0, e = operands.size(); i != e; ++i) {
1514 auto shape = operands.getShape(i);
1515 auto rankDiff = outShape.size() - shape.getRank();
1516
1517 for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
1518 auto dim1 = outShape[i + rankDiff];
1519 auto dim2 = shape.getDimSize(i);
1520
1521 const FailureOr<int64_t> maybeResolvedDim =
1522 resolveBroadcastDim(dim1, dim2);
1523 if (failed(maybeResolvedDim))
1524 return failure();
1525 const int64_t resolvedDim = *maybeResolvedDim;
1526 outShape[i + rankDiff] = resolvedDim;
1527 }
1528 }
1529
1530 return success();
1531}
1532
1533LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
1534 MLIRContext *context, ::std::optional<Location> location,
1535 ArgMaxOp::Adaptor adaptor,
1536 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1537 ShapeAdaptor inputShape(adaptor.getInput().getType());
1538 IntegerAttr axis = adaptor.getProperties().axis;
1539 int32_t axisVal = axis.getValue().getSExtValue();
1540
1541 if (!inputShape.hasRank()) {
1542 inferredReturnShapes.push_back(ShapedTypeComponents());
1543 return success();
1544 }
1545
1546 SmallVector<int64_t> outShape;
1547 outShape.reserve(inputShape.getRank() - 1);
1548 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1549 if (i == axisVal)
1550 continue;
1551 outShape.push_back(inputShape.getDimSize(i));
1552 }
1553
1554 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1555 return success();
1556}
1557
1558LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
1559 MLIRContext *context, ::std::optional<Location> location,
1560 RFFT2dOp::Adaptor adaptor,
1561 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1562 ShapeAdaptor inputShape(adaptor.getInputReal().getType());
1563
1564 if (!inputShape.hasRank())
1565 return failure();
1566
1567 llvm::SmallVector<int64_t> outputShape;
1568 outputShape.resize(3, ShapedType::kDynamic);
1569 outputShape[0] = inputShape.getDimSize(0);
1570 outputShape[1] = inputShape.getDimSize(1);
1571 int64_t inWidth = inputShape.getDimSize(2);
1572
1573 // Note that we can support this calculation symbolically
1574 // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
1575 if (inWidth != ShapedType::kDynamic)
1576 outputShape[2] = inWidth / 2 + 1;
1577
1578 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1579 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1580
1581 return success();
1582}
1583
1584static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
1585 const llvm::StringRef dimName) {
1586 const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
1587 if (!isPowerOfTwo)
1588 return op->emitOpError("expected ")
1589 << dimName << " to be a power of two, got " << dimSize;
1590
1591 return success();
1592}
1593
// Verifies tosa.rfft2d: both results share a shape, static input height and
// width must be powers of two, batch/height must match between input and
// output, and the static output width must be input_width / 2 + 1.
LogicalResult tosa::RFFT2dOp::verify() {
  const auto outputTypes = getResultTypes();
  if (failed(verifyCompatibleShapes(outputTypes)))
    return emitOpError("expected output shapes to match, got ") << outputTypes;

  // The remaining checks require a ranked input.
  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
  if (!inputType)
    return success();

  const int64_t height = inputType.getDimSize(1);
  if (ShapedType::isStatic(height) &&
      failed(verifyDimIsPowerOfTwo(*this, height, "height")))
    return failure();

  const int64_t width = inputType.getDimSize(2);
  if (ShapedType::isStatic(width) &&
      failed(verifyDimIsPowerOfTwo(*this, width, "width")))
    return failure();

  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
  if (!outputType)
    return success();

  // Batch and height input/output dimensions should match
  if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
                                   outputType.getShape().drop_back())))
    return emitOpError("expected batch and height dimensions of input/output "
                       "to match, got input=")
           << inputType << " output=" << outputType;

  // Output width dimension expected to be input_width / 2 + 1
  const int64_t outputWidth = outputType.getDimSize(2);
  if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
      (outputWidth != (width / 2) + 1))
    return emitOpError(
               "expected output width to be equal to input_width / 2 + 1, got ")
           << outputWidth;

  return success();
}
1635
1636LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1637 MLIRContext *context, ::std::optional<Location> location,
1638 FFT2dOp::Adaptor adaptor,
1639 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1640 inferredReturnShapes.push_back(
1641 ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
1642 inferredReturnShapes.push_back(
1643 ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
1644 return success();
1645}
1646
1647LogicalResult tosa::FFT2dOp::verify() {
1648 const auto inputRealType =
1649 llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1650 const auto inputImagType =
1651 llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
1652 if (!inputRealType || !inputImagType)
1653 return success();
1654
1655 const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
1656 return ShapedType::isDynamic(a) ? a : b;
1657 };
1658
1659 const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
1660 inputImagType.getDimSize(1));
1661 if (ShapedType::isStatic(height) &&
1662 failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1663 return failure();
1664
1665 const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
1666 inputImagType.getDimSize(2));
1667 if (ShapedType::isStatic(width) &&
1668 failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1669 return failure();
1670
1671 return success();
1672}
1673
// Infers the concat result shape: non-axis dimensions are merged across all
// ranked operands (and must agree where static); the axis dimension is the
// sum of operand axis sizes, or dynamic if any operand is unranked/dynamic
// along the axis.
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ConcatOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // Infer all dimension sizes by reducing based on inputs.
  const Properties &prop = adaptor.getProperties();
  int32_t axis = prop.axis.getValue().getSExtValue();
  llvm::SmallVector<int64_t> outputShape;
  bool hasRankedInput = false;
  for (auto operand : adaptor.getOperands()) {
    ShapeAdaptor operandShape(operand.getType());
    if (!operandShape.hasRank())
      continue;

    // Copy the Operand's rank.
    if (!hasRankedInput)
      outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);

    // Copy shapes until the dim is non-dynamic.
    for (int i = 0, s = operandShape.getRank(); i < s; i++) {
      if (i == axis || operandShape.isDynamicDim(i))
        continue;
      if (outputShape[i] == ShapedType::kDynamic)
        outputShape[i] = operandShape.getDimSize(i);
      // Any static size recorded so far must agree with this operand's.
      if (outputShape[i] != operandShape.getDimSize(i))
        return emitOptionalError(location,
                                 "Cannot concat tensors with different sizes"
                                 " on the non-axis dimension ",
                                 i);
    }

    hasRankedInput = true;
  }

  if (adaptor.getInput1().empty())
    return failure();

  // All operands share an element type; take it from the first input.
  Type inputType =
      llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
  if (!hasRankedInput) {
    inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
    return success();
  }

  // Determine the dimension size along the concatenation axis.
  int64_t concatDimSize = 0;
  for (auto operand : adaptor.getOperands()) {
    ShapeAdaptor operandShape(operand.getType());

    // We need to know the length of the concatenation axis of all inputs to
    // determine the dimension size of the output shape.
    if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
      concatDimSize = ShapedType::kDynamic;
      break;
    }

    concatDimSize += operandShape.getDimSize(axis);
  }

  outputShape[axis] = concatDimSize;

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
  return success();
}
1738
LogicalResult tosa::ConcatOp::verify() {
  // Verifies CONCAT: every input shares the output's element type, the axis is
  // valid for the first ranked input, all ranked operands agree on rank and on
  // non-axis dimensions, and the output's axis dimension equals the sum of the
  // inputs' axis dimensions (per the spec's ERROR_IF checks).
  // check that each input has same element type as output
  auto outType = getOutput().getType();
  const Operation::operand_range inputList = getInput1();

  // Check there is at least one input
  if (inputList.empty())
    return emitOpError("expect at least one input");

  if (!llvm::all_of(inputList, [&](auto input) {
        return succeeded(verifySameElementTypes(
            *this, /* inType = */ input.getType(), outType));
      })) {
    return failure();
  }

  const int32_t axis = getAxis();
  // The first ranked input (if any) is the reference shape for all later
  // rank/dimension comparisons.
  ShapeAdaptor firstRankedInputShape = nullptr;
  for (const auto &input : inputList) {
    const Type inputType = input.getType();
    ShapeAdaptor currShape(inputType);
    if (currShape.hasRank()) {
      firstRankedInputShape = currShape;
      // Check axis is in expected range
      // NOTE(review): the guard accepts axis == 0, so the message's
      // "0 < axis" understates the accepted range (it is 0 <= axis).
      if (axis < 0 || axis >= firstRankedInputShape.getRank())
        return emitOpError("expect axis to be within range 0 < axis < "
                           "rank(input1[firstRankedTensorIdx]), got ")
               << axis;
      break;
    }
  }

  const auto allOperandsHasRank = [](const Value input) {
    return ShapeAdaptor(input.getType()).hasRank();
  };
  // The shape checks below require every operand to be ranked; otherwise they
  // are skipped entirely.
  if (llvm::all_of(inputList, allOperandsHasRank)) {
    const int64_t firstInputRank = firstRankedInputShape.getRank();

    for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
      const ShapeAdaptor inputShape(input.getType());
      const int64_t inputRank = inputShape.getRank();
      const size_t operandNum = index + 1;

      // Check that each operand has the same rank
      if (inputRank != firstInputRank)
        return emitOpError(
                   "expect all operands to have the same rank, but got ")
               << firstInputRank << " vs " << inputRank << " on operands 0 and "
               << operandNum;

      // Check non-axis dims match
      for (int i = 0; i < inputRank; i++) {
        const int64_t inputDim = inputShape.getDimSize(i);
        const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
        // The concat axis and any dynamic dimension are exempt.
        if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
            inputShape.isDynamicDim(i))
          continue;
        if (inputDim != firstInputDim)
          return emitOpError("expect all operand shapes to have the same sizes "
                             "on non-axis dimensions, but got ")
                 << inputDim << " vs " << firstInputDim << " at index " << i
                 << " on operands 0 and " << operandNum;
      }
    }

    const ShapeAdaptor outputShape(outType);
    if (outputShape.hasRank() && outputShape.getRank() != firstInputRank)
      return emitOpError("expect output rank to match inputs rank, got ")
             << outputShape.getRank() << " vs " << firstInputRank;

    // ERROR_IF(axis_sum != shape[axis]);
    int64_t axisSum = 0;
    for (const auto &input : inputList) {
      const ShapeAdaptor inputShape(input.getType());
      if (inputShape.isDynamicDim(axis)) {
        // make axisSum negative to indicate invalid value
        axisSum = -1;
        break;
      }
      axisSum += inputShape.getDimSize(axis);
    }

    // Only enforce the sum when all input axis dims and the output axis dim
    // are static.
    if (axisSum >= 0 && outputShape.hasRank() &&
        !outputShape.isDynamicDim(axis) &&
        axisSum != outputShape.getDimSize(axis))
      return emitOpError("requires sum of axis dimensions of input1 "
                         "equal to output axis dimension, got ")
             << axisSum << " and " << outputShape.getDimSize(axis);
  }

  return success();
}
1831
1832LogicalResult tosa::EqualOp::inferReturnTypeComponents(
1833 MLIRContext *context, ::std::optional<Location> location,
1834 ValueShapeRange operands, DictionaryAttr attributes,
1835 OpaqueProperties properties, RegionRange regions,
1836 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1837 auto elementType = IntegerType::get(context, /*width=*/1);
1838
1840 if (resolveBroadcastShape(operands, outShape).failed()) {
1841 inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
1842 return success();
1843 }
1844
1845 inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
1846 return success();
1847}
1848
1849bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1850 if (l.size() != r.size() || l.size() != 1)
1851 return false;
1852 return succeeded(verifyCompatibleShape(l[0], r[0]));
1853}
1854
1855LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
1856 MLIRContext *context, ::std::optional<Location> location,
1857 MatMulOp::Adaptor adaptor,
1858 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1859 ShapeAdaptor lhsShape(adaptor.getA().getType());
1860 ShapeAdaptor rhsShape(adaptor.getB().getType());
1861
1862 // All shapes are dynamic.
1863 SmallVector<int64_t> outShape;
1864 outShape.resize(3, ShapedType::kDynamic);
1865
1866 if (lhsShape.hasRank()) {
1867 outShape[0] = lhsShape.getDimSize(0);
1868 outShape[1] = lhsShape.getDimSize(1);
1869 }
1870
1871 if (rhsShape.hasRank()) {
1872 outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
1873 : outShape[0];
1874 outShape[2] = rhsShape.getDimSize(2);
1875 }
1876
1877 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1878 return success();
1879}
1880
LogicalResult MatMulOp::verify() {
  // Verifies MATMUL: both inputs are shaped tensors, quantization (if any) is
  // consistent across operands, and each zero-point operand matches the
  // storage element type of its input.
  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());

  // Must be shaped tensor types
  if (!aType)
    return emitOpError("expect a shaped tensor for input a, got ")
           << getA().getType();

  if (!bType)
    return emitOpError("expect a shaped tensor for input b, got ")
           << getB().getType();

  auto aElementType = aType.getElementType();
  auto bElementType = bType.getElementType();

  auto aQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
  auto bQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);

  // Quantization must be all-or-nothing across the two inputs, and the
  // storage widths must agree.
  if (aQuantizedEType || bQuantizedEType) {
    if (!aQuantizedEType || !bQuantizedEType) {
      return emitOpError("expect operands to be both quantized or both not "
                         "quantized, got ")
             << aElementType << " and " << bElementType;
    }
    // both a and b have quantized element types
    auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
    auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
    if (aQuantWidth != bQuantWidth) {
      return emitOpError("expect quantized operands to have same widths, got ")
             << aQuantWidth << " and " << bQuantWidth;
    }
  }

  // check a_zp and b_zp
  // Zero points are compared against the storage element type so quantized
  // and plain integer inputs are handled uniformly.
  auto aEType = getStorageElementTypeOrSelf(aType);
  auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType());
  if (aEType != aZpEType) {
    return emitOpError("expect input a and a_zp have the same "
                       "element type, got ")
           << aEType << " and " << aZpEType;
  }

  auto bEType = getStorageElementTypeOrSelf(bType);
  auto bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
  if (bEType != bZpEType) {
    return emitOpError("expect input b and b_zp have the same "
                       "element type, got ")
           << bEType << " and " << bZpEType;
  }

  // Value-level zero-point checks run only when the zero point is a known
  // constant; a failed extraction is not an error here.
  FailureOr<int64_t> maybeAZp = getAZeroPoint();
  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
    return failure();

  FailureOr<int64_t> maybeBZp = getBZeroPoint();
  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
    return failure();

  return success();
}
1944
1945LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents(
1946 MLIRContext *context, ::std::optional<Location> location,
1947 MatmulTBlockScaledOp::Adaptor adaptor,
1948 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1949 SmallVector<int64_t, 3> outShape(3, ShapedType::kDynamic);
1950
1951 const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType());
1952 if (aDataShape.hasRank()) {
1953 outShape[0] = aDataShape.getDimSize(0);
1954 outShape[1] = aDataShape.getDimSize(1);
1955 }
1956
1957 const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType());
1958 if (aScaleShape.hasRank()) {
1959 outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0)
1960 : outShape[0];
1961 outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1)
1962 : outShape[1];
1963 }
1964
1965 // If B batch size is 1, it is broadcast across A's batch size
1966 const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType());
1967 if (bDataShape.hasRank()) {
1968 const int64_t bDataBatchSize = bDataShape.getDimSize(0);
1969 if (bDataBatchSize != 1)
1970 outShape[0] =
1971 ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0];
1972 outShape[2] = bDataShape.getDimSize(1);
1973 }
1974
1975 const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType());
1976 if (bScaleShape.hasRank()) {
1977 const int64_t bScaleBatchSize = bScaleShape.getDimSize(0);
1978 if (bScaleBatchSize != 1)
1979 outShape[0] =
1980 ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0];
1981 outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1)
1982 : outShape[2];
1983 }
1984
1985 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1986 return success();
1987}
1988
LogicalResult MatmulTBlockScaledOp::verify() {
  // Verifies MATMUL_T_BLOCK_SCALED: data operands share an element type, the
  // data/scale operands agree on every dimension they share, B's batch is
  // broadcast compatible with A's, C is a multiple of the block size, and the
  // output shape is compatible with [N, H, W].
  // Verify same input data types
  const Type aDataType = getAData().getType();
  const Type bDataType = getBData().getType();
  if (failed(verifySameElementTypes(*this, aDataType, bDataType, "A_data",
                                    "B_data")))
    return failure();

  // Verify input shape compatibility
  // Dimensions reconciled across operands; kDynamic means "not yet known".
  int64_t N = ShapedType::kDynamic;            // A / output batch
  int64_t D = ShapedType::kDynamic;            // B batch (1 => broadcast)
  int64_t H = ShapedType::kDynamic;            // A rows / output height
  int64_t W = ShapedType::kDynamic;            // B rows / output width
  int64_t C = ShapedType::kDynamic;            // channels (reduction dim)
  int64_t multiplesOfC = ShapedType::kDynamic; // C / block_size

  const ShapeAdaptor aDataShape = ShapeAdaptor(aDataType);
  if (aDataShape.hasRank()) {
    N = aDataShape.getDimSize(0);
    H = aDataShape.getDimSize(1);
    C = aDataShape.getDimSize(2);
  }

  // tryUpdateDimOrFailure (helper defined earlier in this file) appears to
  // fill a still-dynamic dim or emit a mismatch diagnostic — confirm against
  // its definition.
  const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType());
  if (aScaleShape.hasRank()) {
    if (failed(tryUpdateDimOrFailure(*this, N, aScaleShape.getDimSize(0),
                                     "a_scale", "batch")) ||
        failed(tryUpdateDimOrFailure(*this, H, aScaleShape.getDimSize(1),
                                     "a_scale", "height")))
      return failure();
    multiplesOfC = aScaleShape.getDimSize(2);
  }

  const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType);
  if (bDataShape.hasRank()) {
    if (failed(tryUpdateDimOrFailure(*this, D, bDataShape.getDimSize(0),
                                     "b_data", "batch")) ||
        failed(tryUpdateDimOrFailure(*this, C, bDataShape.getDimSize(2),
                                     "b_data", "channels")))
      return failure();
    W = bDataShape.getDimSize(1);
  }

  const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType());
  if (bScaleShape.hasRank()) {
    if (failed(tryUpdateDimOrFailure(*this, D, bScaleShape.getDimSize(0),
                                     "b_scale", "batch")) ||
        failed(tryUpdateDimOrFailure(*this, W, bScaleShape.getDimSize(1),
                                     "b_scale", "width")) ||
        failed(tryUpdateDimOrFailure(*this, multiplesOfC,
                                     bScaleShape.getDimSize(2), "b_scale",
                                     "C/block_size")))
      return failure();
  }

  // Verify batch size is broadcast compatible
  if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1)
    return emitOpError("expect B matrix batch size to be broadcast compatible "
                       "with A, got D=")
           << D << " vs N=" << N;

  // Verify C is a multiple of block size
  const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
  if (ShapedType::isStatic(C) && C % blockSize != 0)
    return emitOpError("expect C to be a multiple of block size, got C=")
           << C << ", block_size=" << blockSize;

  // Verify multiplesOfC is C / block size
  if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) &&
      multiplesOfC != C / blockSize)
    return emitOpError(
               "expect scale operands dimension 2 to equal C/block_size (")
           << C << "/" << blockSize << ")" << ", got " << multiplesOfC;

  // Verify output shape
  // If A's batch is unknown, fall back to B's batch as the expected value.
  N = ShapedType::isDynamic(N) ? D : N;
  const SmallVector<int64_t, 3> expectedOutputShape = {N, H, W};
  const auto outputType = cast<ShapedType>(getResult().getType());
  if (outputType.hasRank() &&
      failed(
          verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) {
    // Print dynamic dims as '?' in both shapes of the diagnostic.
    InFlightDiagnostic opError = emitOpError("expected output shape ");
    auto stringifyDim = [&](int64_t d) {
      if (ShapedType::isDynamic(d))
        opError << "?";
      else
        opError << d;
    };
    llvm::interleaveComma(outputType.getShape(), opError, stringifyDim);
    opError << " to be compatible with expected output shape ";
    llvm::interleaveComma(expectedOutputShape, opError, stringifyDim);
    return opError;
  }

  return success();
}
2085
2086LogicalResult tosa::PadOp::inferReturnTypeComponents(
2087 MLIRContext *context, ::std::optional<Location> location,
2088 PadOp::Adaptor adaptor,
2089 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2090 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2091 auto paddingRank =
2092 cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
2093 SmallVector<int64_t> outputShape;
2094
2095 // If the input rank is unknown, we can infer the output rank using the
2096 // padding shape's rank divided by 2.
2097 if (!inputShape.hasRank()) {
2098 outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
2099 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2100 return success();
2101 }
2102
2103 SmallVector<int64_t> paddingValues;
2104 // If the paddings value is not a constant, all dimensions must be dynamic.
2105 if (!tosa::getConstShapeValues(adaptor.getPadding().getDefiningOp(),
2106 paddingValues)) {
2107 outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
2108 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2109 return success();
2110 }
2111
2112 outputShape.reserve(inputShape.getRank());
2113 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2114 if (inputShape.isDynamicDim(i)) {
2115 outputShape.push_back(ShapedType::kDynamic);
2116 continue;
2117 }
2118 auto padFront = paddingValues[i * 2];
2119 auto padBack = paddingValues[i * 2 + 1];
2120 if (padFront < 0 || padBack < 0) {
2121 // if either padding for dim i is -1, output dim is unknown
2122 outputShape.push_back(ShapedType::kDynamic);
2123 continue;
2124 }
2125
2126 outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
2127 }
2128
2129 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2130 return success();
2131}
2132
LogicalResult tosa::PadOp::verify() {
  // Verifies PAD: element types match (input, pad_const, output), ranks
  // match, the constant padding tensor has rank * 2 entries, entries are
  // non-negative or -1 (dynamic), and each static output dim equals
  // input + pad_start + pad_end.
  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
                             /* outType = */ getOutput().getType())
          .failed()) {
    return failure();
  }

  // pad_const is optional; when present it must also match the output
  // element type.
  if (auto padConst = getPadConst()) {
    if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
                               /* outType = */ getOutput().getType())
            .failed()) {
      return failure();
    }
  }

  // Shape checks below require both input and output to be ranked.
  RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
  if (!inputType || !outputType)
    return success();

  if (failed(verifyRanksMatch(getOperation(), inputType, outputType, "input",
                              "output")))
    return failure();

  // Value checks require the padding to be a compile-time constant.
  auto inputRank = inputType.getRank();
  DenseIntElementsAttr paddingAttr;
  if (!matchPattern(getPadding(), m_Constant(&paddingAttr)))
    return success();

  auto paddingValues = paddingAttr.getValues<APInt>();
  if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
    return emitOpError() << "padding tensor must have " << inputRank
                         << " * 2 = " << inputRank * 2 << " elements, but got "
                         << paddingValues.size();

  auto inputShape = inputType.getShape();
  auto outputShape = outputType.getShape();

  for (int64_t i = 0; i < inputRank; ++i) {
    // Padding is laid out as [start0, end0, start1, end1, ...].
    int64_t padStart = paddingValues[i * 2].getSExtValue();
    int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();

    // -1 marks dynamic padding; any other negative value is invalid.
    if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
      return emitOpError()
             << "invalid padding values at dimension " << i
             << ": values must be non-negative or -1 for dynamic padding, got ["
             << padStart << ", " << padEnd << "]";
    }

    // Skip shape verification for dynamic input/output
    if (inputShape[i] == ShapedType::kDynamic ||
        outputShape[i] == ShapedType::kDynamic)
      continue;

    // NOTE(review): a -1 (dynamic) padding entry combined with static
    // input/output dims reaches this arithmetic check — confirm that callers
    // never produce that combination.
    if (outputShape[i] != inputShape[i] + padStart + padEnd) {
      return emitOpError() << "mismatch in output shape at dimension " << i
                           << ": expected " << inputShape[i] << " + "
                           << padStart << " + " << padEnd << " = "
                           << (inputShape[i] + padStart + padEnd)
                           << ", but got " << outputShape[i];
    }
  }

  return success();
}
2200
2201LogicalResult tosa::SliceOp::inferReturnTypeComponents(
2202 MLIRContext *context, ::std::optional<Location> location,
2203 SliceOp::Adaptor adaptor,
2204 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2205
2206 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2209
2210 if (!tosa::getConstShapeValues(adaptor.getStart().getDefiningOp(), start) ||
2211 !tosa::getConstShapeValues(adaptor.getSize().getDefiningOp(), size)) {
2212 auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
2213 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2214 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2215 return success();
2216 }
2217
2218 // if size[i] is -1, all remaining elements in dimension i are included
2219 // in the slice, similar to TF.
2220 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2221 // initialize outputShape to all unknown
2222 SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
2223 if (inputShape.hasRank()) {
2224 for (size_t i = 0; i < size.size(); i++) {
2225 if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
2226 (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
2227 start[i] < inputShape.getDimSize(i))) {
2228 // size[i] is not 0 and not < -1, and start[i] is in valid range
2229 if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
2230 // input shape has unknown dim[i] - only valid if size[i] > 0
2231 if (size[i] > 0) {
2232 outputShape[i] = size[i];
2233 }
2234 } else {
2235 // input shape has known dim[i]
2236 if (size[i] == -1) {
2237 outputShape[i] = inputShape.getDimSize(i) - start[i];
2238 } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
2239 // start[i] + size[i] is within bound of input shape's dim[i]
2240 outputShape[i] = size[i];
2241 }
2242 }
2243 }
2244 }
2245 } else {
2246 outputShape = convertToMlirShape(size);
2247 }
2248 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2249 return success();
2250}
2251
LogicalResult tosa::SliceOp::verify() {
  // Verifies SLICE: element types match, start/size lengths equal the input
  // rank, constant start/size values are in range, the output shape matches
  // the size values, and start + size fits within each static input dim.
  const Value input = getInput1();
  const Value output = getOutput();
  if (verifySameElementTypes(*this, /* inType = */ input.getType(),
                             /* outType = */ output.getType())
          .failed())
    return failure();

  const Value start = getStart();
  const Value size = getSize();
  const ShapeAdaptor inputShape(input.getType());
  const ShapeAdaptor outputShape(output.getType());

  if (inputShape.hasRank()) {
    const auto inputRank = inputShape.getRank();
    if (outputShape.hasRank() && inputRank != outputShape.getRank())
      return emitOpError(
                 "expect input1 and output to have the same ranks, got ")
             << inputRank << " and " << outputShape.getRank();

    // start and size are !tosa.shape values; their rank is their length.
    const auto startShapeRank =
        llvm::cast<tosa::shapeType>(start.getType()).getRank();
    if (inputRank != startShapeRank)
      return emitOpError("length of start is not equal to rank of input shape");

    const auto sizeShapeRank =
        llvm::cast<tosa::shapeType>(size.getType()).getRank();
    if (inputRank != sizeShapeRank)
      return emitOpError("length of size is not equal to rank of input shape");
  }

  // Value checks apply only when start/size are compile-time constants.
  // kInferableDimSize (declared elsewhere in this file; presumably -1 — see
  // its definition) marks an entry whose value is to be inferred.
  SmallVector<int64_t> startValues;
  tosa::getConstShapeValues(start.getDefiningOp(), startValues);
  if (startValues.size()) {
    if (llvm::any_of(startValues, [](const int64_t v) {
          return v < 0 && v != kInferableDimSize;
        }))
      return emitOpError("start values must be non-negative, got [")
             << startValues << "]";
  }

  SmallVector<int64_t> sizeValues;
  if (!tosa::getConstShapeValues(size.getDefiningOp(), sizeValues))
    return success();

  if (llvm::any_of(sizeValues, [](const int64_t v) {
        return v <= 0 && v != kInferableDimSize;
      }))
    return emitOpError("size values must be > 0, got [") << sizeValues << "]";
  // Output-vs-size compatibility is only enforced when no size entry is
  // inferable.
  if (outputShape.hasRank()) {
    SmallVector<int64_t> outputDims;
    outputShape.getDims(outputDims);
    const bool hasNoInferableDims = llvm::all_of(
        sizeValues, [](const int64_t v) { return v != kInferableDimSize; });
    if (hasNoInferableDims &&
        failed(verifyCompatibleShape(outputDims, sizeValues)))
      return emitOpError("expected output shape to match size values, got ")
             << output.getType() << " vs [" << sizeValues << "]";
  }

  // Bounds check: each fully-static slice must lie within its input dim.
  if (inputShape.hasRank() && startValues.size()) {
    SmallVector<int64_t> inputDims;
    inputShape.getDims(inputDims);
    for (const auto &[index, vals] :
         llvm::enumerate(llvm::zip_equal(startValues, sizeValues, inputDims))) {
      const auto &[start, size, inputDim] = vals;
      if (start == kInferableDimSize || size == kInferableDimSize ||
          ShapedType::isDynamic(inputDim))
        continue;
      if (start + size > inputDim)
        return emitOpError("start + size must be less than or equal to input "
                           "dimension size, got start=")
               << start << ", size=" << size
               << " vs input dim size=" << inputDim << " at dimension "
               << index;
    }
  }

  return success();
}
2332
2333LogicalResult tosa::MulOp::inferReturnTypeComponents(
2334 MLIRContext *context, ::std::optional<Location> location,
2335 ValueShapeRange operands, DictionaryAttr attributes,
2336 OpaqueProperties properties, RegionRange regions,
2337 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2338 // mul op's output shape only depend on input1 and input2, not on shift
2339 ValueShapeRange twoInputs = operands.drop_back();
2341 if (resolveBroadcastShape(twoInputs, outShape).failed()) {
2342 inferredReturnShapes.push_back(ShapedTypeComponents());
2343 } else {
2344 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
2345 }
2346 return success();
2347}
2348
2349LogicalResult tosa::MulOp::verify() {
2350 const Value output = getOutput();
2351 auto resElemType = getElementTypeOrSelf(output);
2352
2353 // Verify if the element type among operands and result match tosa
2354 // specification.
2355 if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
2356 IntegerType lhsIntType =
2357 dyn_cast<IntegerType>(getElementTypeOrSelf(getInput1()));
2358 IntegerType rhsIntType =
2359 dyn_cast<IntegerType>(getElementTypeOrSelf(getInput2()));
2360 if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
2361 return emitOpError("requires the same element type for all operands");
2362
2363 // Though the spec requires the element type of result to be i32, a more
2364 // relaxed way is provided at dialect level for easier cooperating with
2365 // other dialects.
2366 if (lhsIntType.getWidth() > resIntType.getWidth())
2367 return emitOpError("invalid data type size for operands or result");
2368
2369 } else {
2370 // For other supported type, the spec requires requires the same element
2371 // type for all operands (excludes `shift` operand) and results.
2372 for (int i = 0; i < 2; ++i) {
2373 if (getElementTypeOrSelf(getOperand(i)) != resElemType)
2374 return emitOpError(
2375 "requires the same element type for all operands and results");
2376 }
2377
2378 // verify shift has value 0 for non-integer types
2379 ElementsAttr shiftElem;
2380 if (matchPattern(getShift(), m_Constant(&shiftElem))) {
2381 int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
2382 if (shift != 0) {
2383 return emitOpError() << "require shift to be 0 for float type";
2384 }
2385 }
2386 }
2387
2388 // Verify the op has same ranks for all main operands (excludes extra operands
2389 // such as shift of mul op, so this is the only difference with the built-in
2390 // `SameOperandsAndResultRank` trait) and results types, if known.
2391 TypeRange operandTypes = getOperandTypes();
2392 ShapedType aType = cast<ShapedType>(operandTypes[0]);
2393 ShapedType bType = cast<ShapedType>(operandTypes[1]);
2394
2395 const bool aHasRank = aType.hasRank();
2396 const bool bHasRank = bType.hasRank();
2397 if (aHasRank && bHasRank) {
2398 const int64_t aRank = aType.getRank();
2399 const int64_t bRank = bType.getRank();
2400 if (aRank != bRank)
2401 return emitOpError("a and b operands don't have matching ranks, got ")
2402 << aRank << " and " << bRank;
2403
2404 // check for broadcast compatible shapes
2405 SmallVector<int64_t> resultShape;
2407 aType.getShape(), bType.getShape(), resultShape))
2408 return emitOpError("a and b operands don't have broadcast-compatible "
2409 "shapes, got ")
2410 << aType << " and " << bType;
2411 }
2412
2413 ShapedType resultType = cast<ShapedType>(output.getType());
2414 if (!resultType.hasRank())
2415 return success();
2416
2417 const int64_t resultRank = resultType.getRank();
2418 if (aHasRank && resultRank != aType.getRank())
2419 return emitOpError("result type has different rank than a, got ")
2420 << resultRank << " vs " << aType.getRank();
2421 if (bHasRank && resultRank != bType.getRank())
2422 return emitOpError("result type has different rank than b, got ")
2423 << resultRank << " vs " << bType.getRank();
2424
2425 return success();
2426}
2427
2428LogicalResult tosa::TableOp::inferReturnTypeComponents(
2429 MLIRContext *context, ::std::optional<Location> location,
2430 TableOp::Adaptor adaptor,
2431 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2432 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2433
2434 if (!inputShape.hasRank()) {
2435 inferredReturnShapes.push_back(ShapedTypeComponents());
2436 return success();
2437 }
2438
2439 inferredReturnShapes.resize(1);
2440 inputShape.getDims(inferredReturnShapes[0]);
2441 return success();
2442}
2443
2444LogicalResult tosa::TableOp::verify() {
2445 const TensorType inputType = getInput1().getType();
2446 const TensorType outputType = getOutput().getType();
2447
2448 if (!inputType.hasRank() || !outputType.hasRank())
2449 return success();
2450
2451 if (failed(verifyRanksMatch(getOperation(), inputType, outputType, "input",
2452 "result")))
2453 return failure();
2454
2455 auto inputDims = inputType.getShape();
2456 auto outputDims = outputType.getShape();
2457 for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
2458 int64_t dim = it.index();
2459 auto [inputDim, outputDim] = it.value();
2460 if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
2461 return emitOpError() << "dim(result, " << dim << ") = " << outputDim
2462 << " doesn't match dim(input, " << dim
2463 << ") = " << inputDim;
2464 }
2465 }
2466 return success();
2467}
2468
2469LogicalResult
2470tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
2471 // Multiples must be constants.
2472 DenseIntElementsAttr multiplesAttr;
2473 if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr)))
2474 return failure();
2475 multiples =
2476 llvm::map_to_vector(multiplesAttr.getValues<APInt>(),
2477 [](const APInt &val) { return val.getSExtValue(); });
2478 return success();
2479}
2480
2481LogicalResult tosa::TileOp::inferReturnTypeComponents(
2482 MLIRContext *context, ::std::optional<Location> location,
2483 TileOp::Adaptor adaptor,
2484 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2485 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2486 SmallVector<int64_t> multiples;
2487 if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(),
2488 multiples)) {
2489 auto rank =
2490 cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
2491 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2492 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2493 return success();
2494 } else {
2495 multiples = convertToMlirShape(multiples);
2496 }
2497
2498 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2499 SmallVector<int64_t> outputShape;
2500 if (!inputShape.hasRank()) {
2501 outputShape.resize(multiples.size(), ShapedType::kDynamic);
2502 inferredReturnShapes.push_back(
2503 ShapedTypeComponents(outputShape, inputType));
2504 return success();
2505 } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
2506 return failure();
2507
2508 // Any non dynamic dimension can be multiplied to a known size.
2509 outputShape.reserve(multiples.size());
2510 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2511 if (multiples[i] == ShapedType::kDynamic) {
2512 outputShape.push_back(ShapedType::kDynamic);
2513 } else {
2514 int64_t dim = inputShape.getDimSize(i);
2515 if (dim != ShapedType::kDynamic)
2516 dim *= multiples[i];
2517 outputShape.push_back(dim);
2518 }
2519 }
2520
2521 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
2522 return success();
2523}
2524
LogicalResult tosa::TileOp::verify() {
  // Verifies TILE: element types match, the 'multiples' shape operand has one
  // entry per input/output dimension, and each constant multiple is a
  // positive integer or -1.
  if (verifySameElementTypes(*this, /* intype = */ getInput1().getType(),
                             /* outType = */ getOutput().getType())
          .failed()) {
    return failure();
  }
  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
  ShapedType outputType = llvm::cast<ShapedType>(getType());

  // 'multiples' is a !tosa.shape; its rank is its number of entries.
  shapeType multiplesType =
      llvm::cast<tosa::shapeType>(getMultiples().getType());

  auto multiplesRank = multiplesType.getRank();

  if (inputType.hasRank()) {
    if (inputType.getRank() != multiplesRank)
      return emitOpError("expect 'multiples' to have rank ")
             << inputType.getRank() << " but got " << multiplesRank << ".";
    if (outputType.hasRank() &&
        failed(verifyRanksMatch(getOperation(), inputType, outputType, "input",
                                "output")))
      return failure();
  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
    // With an unranked input, the output rank stands in for the input rank.
    return emitOpError("expect 'multiples' array to have length ")
           << outputType.getRank() << " but got " << multiplesRank << ".";

  // Value check applies only when the multiples fold to a constant; -1 marks
  // a dynamic entry and is allowed.
  SmallVector<int64_t> multiples;
  if (getConstantMultiples(multiples).succeeded() &&
      llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
    return emitOpError(
        "expect element of 'multiples' to be positive integer or -1.");

  return success();
}
2559
2560bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
2561 if (l.size() != r.size() || l.size() != 1)
2562 return false;
2563 return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
2564}
2565
// Infers the result shape of tosa.reshape from the (ideally constant) target
// shape operand, resolving at most one dynamic dimension from the input's
// static element count.
LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ReshapeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
  llvm::SmallVector<int64_t> newShapeValue;
  if (!tosa::getConstShapeValues(adaptor.getShape().getDefiningOp(),
                                 newShapeValue)) {
    // Non-constant target shape: all we know is the rank implied by the shape
    // operand's type, so emit an all-dynamic result of that rank.
    auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
    SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
    inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
    return success();
  } else {
    // Translate TOSA's -1 "inferred" marker into MLIR's dynamic-dim encoding.
    newShapeValue = convertToMlirShape(newShapeValue);
  }

  // We cannot infer from the total number of elements so we must take the
  // shape attribute as exact.
  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
    inferredReturnShapes.push_back(
        ShapedTypeComponents(newShapeValue, inputType));
    return success();
  }

  // Determine the number of elements covered by the slice of all static
  // dimensions. This allows us to infer the length of the remaining dynamic
  // dimension.
  int64_t numElements = inputShape.getNumElements();
  int64_t staticMul = 1;
  for (auto val : newShapeValue) {
    if (ShapedType::isStatic(val)) {
      staticMul *= val;
    }
  }

  // Determine the length of the dynamic dimension.
  for (auto &val : newShapeValue) {
    if (ShapedType::isDynamic(val))
      val = numElements / staticMul;
  }

  inferredReturnShapes.push_back(
      ShapedTypeComponents(newShapeValue, inputType));
  return success();
}
2612
2613llvm::LogicalResult tosa::ReshapeOp::verify() {
2614 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2615 /* outType = */ getOutput().getType())
2616 .failed()) {
2617 return failure();
2618 }
2619 TensorType inputType = getInput1().getType();
2620
2621 SmallVector<int64_t> shapeValues;
2622 if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) {
2623 // skip following checks if shape is not constant
2624 return mlir::success();
2625 }
2626
2627 int missingDims = llvm::count(shapeValues, kInferableDimSize);
2628 if (missingDims > 1)
2629 return emitOpError() << "expected at most one target dimension to be "
2631
2632 const auto outputType = dyn_cast<RankedTensorType>(getType());
2633 if (!outputType)
2634 return success();
2635
2636 if ((int64_t)shapeValues.size() != outputType.getRank())
2637 return emitOpError() << "new shape does not match result rank";
2638
2639 for (auto [newShapeDim, outputShapeDim] :
2640 zip(shapeValues, outputType.getShape())) {
2641 if (newShapeDim != kInferableDimSize &&
2642 newShapeDim != ShapedType::kDynamic &&
2643 outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
2644 return emitOpError() << "new shape is inconsistent with result shape";
2645
2646 if (newShapeDim != ShapedType::kDynamic && newShapeDim < kInferableDimSize)
2647 return emitOpError() << "new shape has invalid tensor dimension size "
2648 << newShapeDim;
2649 }
2650
2651 if (inputType.hasStaticShape()) {
2652 int64_t inputElementsNum = inputType.getNumElements();
2653 if (outputType.hasStaticShape()) {
2654 int64_t outputElementsNum = outputType.getNumElements();
2655 if (inputElementsNum != outputElementsNum) {
2656 return emitOpError() << "cannot reshape " << inputElementsNum
2657 << " elements into " << outputElementsNum;
2658 }
2659 }
2660
2661 int64_t newShapeElementsNum =
2662 llvm::accumulate(shapeValues, int64_t(1), [](int64_t acc, int64_t dim) {
2663 return (dim > 0) ? acc * dim : acc;
2664 });
2665 bool isStaticNewShape =
2666 llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
2667 if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
2668 (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
2669 return emitOpError() << "cannot reshape " << inputElementsNum
2670 << " elements into " << newShapeElementsNum;
2671 }
2672 }
2673
2674 return mlir::success();
2675}
2676
2677// return failure if val is not a constant
2678// set zp to -1 if val is non-zero float or val is not integer nor float
2679// otherwise set zp to val's constant value
2680static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
2681 ElementsAttr zpAttr;
2682 if (!matchPattern(val, m_Constant(&zpAttr))) {
2683 return failure();
2684 }
2685
2686 Type zpElemType = zpAttr.getElementType();
2687
2688 if (llvm::isa<FloatType>(zpElemType)) {
2689 if (zpAttr.getValues<APFloat>()[0].isZero()) {
2690 return 0;
2691 }
2692 // return non-zero value to trigger error check
2693 return -1;
2694 }
2695
2696 if (llvm::isa<IntegerType>(zpElemType)) {
2697 if (signExtend)
2698 return zpAttr.getValues<APInt>()[0].getSExtValue();
2699 else
2700 return zpAttr.getValues<APInt>()[0].getZExtValue();
2701 }
2702
2703 // return non-zero value to trigger error check
2704 return -1;
2705}
2706
2707template <typename T>
2708static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
2709 const std::string &operand) {
2710 Type zpElemType = getElementTypeOrSelf(val);
2711
2712 if (!zpElemType.isInteger(8) && zp != 0) {
2713 // convert operand to lower case for error message
2714 std::string lower = operand;
2715 llvm::transform(lower, lower.begin(), ::tolower);
2716 return op.emitOpError()
2717 << lower << " zero point must be zero for non-int8 integer types";
2718 }
2719
2720 return success();
2721}
2722
2723static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
2724 const int64_t &zp,
2725 const std::string &operand) {
2726 bool isInputZp = (operand == "Input");
2727
2728 bool tensorUnsigned =
2729 isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
2730 StringRef tensorName = isInputZp ? "input" : "output";
2731
2732 Type zpElemType = getElementTypeOrSelf(zpVal);
2733
2734 if (zp != 0) {
2735 if (!zpElemType.isInteger(8) &&
2736 !(zpElemType.isInteger(16) && tensorUnsigned)) {
2737 return op.emitOpError()
2738 << "expect " << tensorName << "_zp of 0, got " << zp;
2739 }
2740 if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
2741 return op.emitOpError() << "expect " << tensorName
2742 << "_zp of 0 or 32768 for unsigned int16 "
2743 << tensorName << ", got " << zp;
2744 }
2745 }
2746
2747 return success();
2748}
2749
// Instantiates the get<Operand>ZeroPoint / verify<Operand>ZeroPoint pair for
// each (op, operand) combination below. SIGN_EXTEND selects whether the
// constant zero point is read back sign- or zero-extended (for RescaleOp it
// depends on the op's signedness flags).
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)                       \
  FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() {                \
    return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND);                 \
  }                                                                            \
  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) {        \
    return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
  }

ZERO_POINT_HELPER(Conv2DOp, Input, true)
ZERO_POINT_HELPER(Conv2DOp, Weight, true)
ZERO_POINT_HELPER(Conv3DOp, Input, true)
ZERO_POINT_HELPER(Conv3DOp, Weight, true)
ZERO_POINT_HELPER(DepthwiseConv2DOp, Input, true)
ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight, true)
ZERO_POINT_HELPER(TransposeConv2DOp, Input, true)
ZERO_POINT_HELPER(TransposeConv2DOp, Weight, true)
ZERO_POINT_HELPER(AvgPool2dOp, Input, true)
ZERO_POINT_HELPER(AvgPool2dOp, Output, true)
ZERO_POINT_HELPER(MatMulOp, A, true)
ZERO_POINT_HELPER(MatMulOp, B, true)
ZERO_POINT_HELPER(NegateOp, Input1, true)
ZERO_POINT_HELPER(NegateOp, Output, true)
ZERO_POINT_HELPER(RescaleOp, Input, !getInputUnsigned())
ZERO_POINT_HELPER(RescaleOp, Output, !getOutputUnsigned())
#undef ZERO_POINT_HELPER
2775
// Infers the result shape of tosa.transpose by permuting the input dims
// according to the constant 'perms' attribute.
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  // If input rank and permutation length is unknown, the output rank is
  // unknown.
  if (!inputShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  const auto inputRank = inputShape.getRank();

  // This would imply the number of permutations does not match the rank of
  // the input which is illegal.
  if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {
    return failure();
  }

  SmallVector<int64_t> outputShape;
  // Rank-0 means no permutations matter.
  if (inputRank == 0) {
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // Check whether the input dimensions are all the same.
  bool allTheSame = true;
  for (int i = 1, s = inputRank; i < s; i++) {
    if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
      allTheSame = false;
      break;
    }
  }

  // If all of the input dimensions are the same we don't care about the
  // permutation.
  if (allTheSame) {
    outputShape.resize(inputRank, inputShape.getDimSize(0));
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  outputShape.resize(inputRank, ShapedType::kDynamic);

  // Constant permutation values must be within the input rank.
  if (llvm::any_of(adaptor.getPerms(),
                   [inputRank](const auto i) { return i >= inputRank; }))
    return failure();

  // Result dim i takes its size from input dim perms[i].
  outputShape.reserve(inputRank);
  for (int i = 0, s = inputRank; i < s; i++) {
    outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
2836
// Verifies tosa.transpose: matching element types, 'perms' length equal to
// both input and output rank, 'perms' a valid permutation, preserved element
// count, and output dims matching the permuted input dims.
LogicalResult tosa::TransposeOp::verify() {
  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
                             /* outType = */ getOutput().getType())
          .failed()) {
    return failure();
  }

  const ShapeAdaptor inputShape(getInput1().getType());
  const ShapeAdaptor outputShape(getOutput().getType());

  const llvm::ArrayRef<int32_t> constantPerms = getPerms();

  // 'perms' must list exactly one source dim per input dimension.
  if (inputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << inputShape.getRank()
                         << " (input rank) but got size "
                         << constantPerms.size();

  if (inputShape.hasRank() && outputShape.hasRank() &&
      inputShape.getRank() != outputShape.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  if (outputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << outputShape.getRank()
                         << " (output rank) but got size "
                         << constantPerms.size();

  // Every entry must be an in-range index, and together they must form a
  // permutation (no duplicates).
  if (!llvm::all_of(constantPerms,
                    [&constantPerms](int32_t s) {
                      return s >= 0 &&
                             static_cast<size_t>(s) < constantPerms.size();
                    }) ||
      !isPermutationVector(llvm::map_to_vector(
          constantPerms, [](int32_t v) -> int64_t { return v; })))
    return emitOpError() << "expected valid permutation indices";

  // ERROR_IF(tensor_size(shape1) != tensor_size(shape))
  if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
      inputShape.getNumElements() != outputShape.getNumElements())
    return emitOpError() << "expected input1 and output to have same numbers "
                            "of elements, got "
                         << inputShape.getNumElements() << " and "
                         << outputShape.getNumElements();

  // Verify that the types of the input and output tensors are properly
  // permuted.
  if (inputShape.hasRank() && outputShape.hasRank()) {
    for (auto i = 0; i < outputShape.getRank(); i++) {
      // Dynamic dims on either side cannot be contradicted.
      if (inputShape.isDynamicDim(constantPerms[i]) ||
          outputShape.isDynamicDim(i))
        continue;

      if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
        return emitOpError()
               << "expected output tensor dim " << i << " to match "
               << "input dim " << constantPerms[i] << " with value of "
               << inputShape.getDimSize(constantPerms[i]);
    }
  }

  return success();
}
2903
2904LogicalResult TransposeOp::reifyResultShapes(
2905 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2906
2907 const llvm::ArrayRef<int32_t> transposePerms = getPerms();
2908
2909 Value input = getInput1();
2910 auto inputType = cast<TensorType>(input.getType());
2911
2912 SmallVector<OpFoldResult> returnedDims(inputType.getRank());
2913 for (auto dim : transposePerms) {
2914 int32_t dimInInput = transposePerms[dim];
2915 if (inputType.isDynamicDim(dimInInput))
2916 returnedDims[dim] =
2917 tensor::DimOp::create(builder, getLoc(), input, dimInInput)
2918 .getResult();
2919 else
2920 returnedDims[dim] =
2921 builder.getIndexAttr(inputType.getDimSize(dimInInput));
2922 }
2923
2924 reifiedReturnShapes.emplace_back(std::move(returnedDims));
2925 return success();
2926}
2927
2928LogicalResult tosa::GatherOp::inferReturnTypeComponents(
2929 MLIRContext *context, ::std::optional<Location> location,
2930 GatherOp::Adaptor adaptor,
2931 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2932 llvm::SmallVector<int64_t> outputShape;
2933 outputShape.resize(3, ShapedType::kDynamic);
2934
2935 ShapeAdaptor valuesShape(adaptor.getValues().getType());
2936 if (valuesShape.hasRank()) {
2937 outputShape[0] = valuesShape.getDimSize(0);
2938 outputShape[2] = valuesShape.getDimSize(2);
2939 }
2940
2941 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2942 if (indicesShape.hasRank()) {
2943 if (outputShape[0] == ShapedType::kDynamic)
2944 outputShape[0] = indicesShape.getDimSize(0);
2945 if (outputShape[1] == ShapedType::kDynamic)
2946 outputShape[1] = indicesShape.getDimSize(1);
2947 }
2948
2949 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2950 return success();
2951}
2952
// Verifies tosa.gather: element types match, and the N/W/C dims are
// consistent across values ([N, ?, C] read here), indices ([N, W]) and
// output ([N, W, C]) wherever both sides are static.
LogicalResult tosa::GatherOp::verify() {
  if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
                             /* outType = */ getOutput().getType())
          .failed()) {
    return failure();
  }

  const ShapeAdaptor valuesShape(getValues().getType());
  const ShapeAdaptor indicesShape(getIndices().getType());
  const ShapeAdaptor outputShape(getOutput().getType());

  // Best-known values of the shared dims, refined operand by operand.
  int64_t n = ShapedType::kDynamic;
  int64_t w = ShapedType::kDynamic;
  int64_t c = ShapedType::kDynamic;

  if (valuesShape.hasRank()) {
    n = valuesShape.getDimSize(0);
    c = valuesShape.getDimSize(2);
  }
  if (indicesShape.hasRank()) {
    const int64_t indicesN = indicesShape.getDimSize(0);
    w = indicesShape.getDimSize(1);
    if (n == ShapedType::kDynamic)
      n = indicesN;
    else if (indicesN != ShapedType::kDynamic && n != indicesN)
      return emitOpError() << "requires indices dimension 0 to have size " << n
                           << ", got " << indicesN;
  }
  if (outputShape.hasRank()) {
    const int64_t outputN = outputShape.getDimSize(0);
    const int64_t outputW = outputShape.getDimSize(1);
    const int64_t outputC = outputShape.getDimSize(2);
    if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
        n != outputN)
      return emitOpError() << "requires output dimension 0 to have size " << n
                           << ", got " << outputN;

    if (w != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
        w != outputW)
      return emitOpError() << "requires output dimension 1 to have size " << w
                           << ", got " << outputW;
    if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
        c != outputC)
      return emitOpError() << "requires output dimension 2 to have size " << c
                           << ", got " << outputC;
  }
  return success();
}
3001
// Infers the NHWC result shape of tosa.resize from the input's H/W and the
// constant scale/offset/border shape operands; fails (leaving the shape
// unresolved) when input H/W or any of those operands is non-constant.
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ResizeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t, 4> outputShape;
  outputShape.resize(4, ShapedType::kDynamic);

  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (!inputShape.hasRank())
    return failure();

  // Batch and channel pass through unchanged.
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[3] = inputShape.getDimSize(3);
  int64_t inputHeight = inputShape.getDimSize(1);
  int64_t inputWidth = inputShape.getDimSize(2);

  if ((inputHeight == ShapedType::kDynamic) ||
      (inputWidth == ShapedType::kDynamic))
    return failure();

  SmallVector<int64_t> scaleInt, offsetInt, borderInt;
  if (!tosa::getConstShapeValues(adaptor.getScale().getDefiningOp(),
                                 scaleInt) ||
      !tosa::getConstShapeValues(adaptor.getOffset().getDefiningOp(),
                                 offsetInt) ||
      !tosa::getConstShapeValues(adaptor.getBorder().getDefiningOp(),
                                 borderInt)) {
    return failure();
  }

  // Compute the output shape based on attributes: scale, offset, and border.
  // scale is [y_n, y_d, x_n, x_d]; offset/border are [y, x].
  const int64_t outputHeight =
      (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
       scaleInt[1]) +
      1;

  const int64_t outputWidth =
      (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
       scaleInt[3]) +
      1;

  if (outputHeight < 0 || outputWidth < 0) {
    return emitOptionalError(
        location,
        "calculated output height and width must be non-negative, "
        "got height = ",
        outputHeight, ", width = ", outputWidth);
  }

  outputShape[1] = outputHeight;
  outputShape[2] = outputWidth;
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
3056
// Verifies tosa.resize: positive scale values, and — when the shape operands
// are constant and input/output are ranked — that the scaled input H/W divide
// evenly and reproduce the declared output H/W.
LogicalResult tosa::ResizeOp::verify() {
  const Value input = getInput();
  const Value output = getOutput();
  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(input.getType());
  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(output.getType());

  SmallVector<int64_t> scaleValues;
  SmallVector<int64_t> offsetValues;
  SmallVector<int64_t> borderValues;
  if (!tosa::getConstShapeValues(getScale().getDefiningOp(), scaleValues) ||
      !tosa::getConstShapeValues(getOffset().getDefiningOp(), offsetValues) ||
      !tosa::getConstShapeValues(getBorder().getDefiningOp(), borderValues)) {
    // Skip following checks if shape is not constant
    return success();
  }

  if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
    return emitOpError("expect all scale values to be > 0, got ")
           << scaleValues;

  // scale is [y_n, y_d, x_n, x_d]; offset/border are [y, x].
  const int64_t scaleYN = scaleValues[0];
  const int64_t scaleYD = scaleValues[1];
  const int64_t scaleXN = scaleValues[2];
  const int64_t scaleXD = scaleValues[3];

  const int64_t offsetY = offsetValues[0];
  const int64_t offsetX = offsetValues[1];

  const int64_t borderY = borderValues[0];
  const int64_t borderX = borderValues[1];

  // Without ranked input and output types there is nothing more to check.
  if (!inputType)
    return success();
  if (!outputType)
    return success();

  const int64_t oh = outputType.getDimSize(1);
  const int64_t ow = outputType.getDimSize(2);
  const int64_t ih = inputType.getDimSize(1);
  const int64_t iw = inputType.getDimSize(2);

  // Don't check with input height that could be broadcast (ih != 1)
  // since Linalg, a consumer of TOSA, expects broadcasting support
  // in resize to be available. Taking the cautious approach for now,
  // we can consider removing support for broadcasting later.
  if (ih != ShapedType::kDynamic && ih != 1) {
    const std::optional<int64_t> calculatedOutHeightMinusOne =
        idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
    if (!calculatedOutHeightMinusOne.has_value())
      return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
                         "border_y ")
             << "to be wholly divisible by scale_y_d, got ((" << ih
             << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
             << ") / " << scaleYD;
    const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
    if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
      return emitOpError("calculated output height did not match expected: ")
             << "calculated=" << calculatedOutHeight << ", expected=" << oh;
  }

  // Don't check with input width that could be broadcast (iw != 1)
  // since Linalg, a consumer of TOSA, expects broadcasting support
  // in resize to be available. Taking the cautious approach for now,
  // we can consider removing support for broadcasting later.
  if (iw != ShapedType::kDynamic && iw != 1) {
    const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
    const std::optional<int64_t> calculatedOutWidthMinusOne =
        idivCheck(scaledInWidth, scaleXD);
    if (!calculatedOutWidthMinusOne.has_value())
      return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
                         "border_x ")
             << "to be wholly divisible by scale_x_d, got ((" << iw
             << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
             << ") / " << scaleXD;
    const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
    if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
      return emitOpError("calculated output width did not match expected: ")
             << "calculated=" << calculatedOutWidth << ", expected=" << ow;
  }

  return success();
}
3141
3142LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
3143 MLIRContext *context, ::std::optional<Location> location,
3144 ScatterOp::Adaptor adaptor,
3145 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3146 llvm::SmallVector<int64_t> outputShape;
3147 outputShape.resize(3, ShapedType::kDynamic);
3148
3149 ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
3150 if (valuesInShape.hasRank()) {
3151 outputShape[0] = valuesInShape.getDimSize(0);
3152 outputShape[1] = valuesInShape.getDimSize(1);
3153 outputShape[2] = valuesInShape.getDimSize(2);
3154 }
3155
3156 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3157 if (indicesShape.hasRank()) {
3158 if (outputShape[0] == ShapedType::kDynamic)
3159 outputShape[0] = indicesShape.getDimSize(0);
3160 }
3161
3162 ShapeAdaptor inputShape(adaptor.getInput().getType());
3163 if (inputShape.hasRank()) {
3164 if (outputShape[0] == ShapedType::kDynamic)
3165 outputShape[0] = inputShape.getDimSize(0);
3166 if (outputShape[2] == ShapedType::kDynamic)
3167 outputShape[2] = inputShape.getDimSize(2);
3168 }
3169
3170 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3171 return success();
3172}
3173
// Verifies tosa.scatter: element types match across values_in/input/output,
// the N/K/W/C dims agree wherever static (values_in [N, K, C], indices
// [N, W], input [N, W, C], values_out [N, K, C]), and K >= W so every
// scattered row fits.
LogicalResult tosa::ScatterOp::verify() {
  if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
                             /* outType = */ getValuesOut().getType())
          .failed() ||
      verifySameElementTypes(*this, /* inType = */ getInput().getType(),
                             /* outType = */ getValuesOut().getType())
          .failed()) {
    return failure();
  }

  const ShapeAdaptor valuesInShape(getValuesIn().getType());
  const ShapeAdaptor indicesShape(getIndices().getType());
  const ShapeAdaptor inputShape(getInput().getType());
  const ShapeAdaptor outputShape(getValuesOut().getType());

  // Best-known values of the shared dims, refined operand by operand.
  int64_t n = ShapedType::kDynamic;
  int64_t k = ShapedType::kDynamic;
  int64_t w = ShapedType::kDynamic;
  int64_t c = ShapedType::kDynamic;
  if (valuesInShape.hasRank()) {
    n = valuesInShape.getDimSize(0);
    k = valuesInShape.getDimSize(1);
    c = valuesInShape.getDimSize(2);
  }
  if (indicesShape.hasRank()) {
    const int64_t indicesN = indicesShape.getDimSize(0);
    w = indicesShape.getDimSize(1);
    if (n == ShapedType::kDynamic)
      n = indicesN;
    else if (indicesN != ShapedType::kDynamic && n != indicesN)
      return emitOpError() << "requires indices dimension 0 to have size " << n
                           << ", got " << indicesN;
  }
  if (inputShape.hasRank()) {
    const int64_t inputN = inputShape.getDimSize(0);
    const int64_t inputW = inputShape.getDimSize(1);
    const int64_t inputC = inputShape.getDimSize(2);
    if (n == ShapedType::kDynamic)
      n = inputN;
    else if (inputN != ShapedType::kDynamic && n != inputN)
      return emitOpError() << "requires input dimension 0 to have size " << n
                           << ", got " << inputN;
    if (w == ShapedType::kDynamic)
      w = inputW;
    else if (inputW != ShapedType::kDynamic && w != inputW)
      return emitOpError() << "requires input dimension 1 to have size " << w
                           << ", got " << inputW;

    if (c == ShapedType::kDynamic)
      c = inputC;
    else if (inputC != ShapedType::kDynamic && c != inputC)
      return emitOpError() << "requires input dimension 2 to have size " << c
                           << ", got " << inputC;
  }
  if (outputShape.hasRank()) {
    const int64_t outputN = outputShape.getDimSize(0);
    const int64_t outputK = outputShape.getDimSize(1);
    const int64_t outputC = outputShape.getDimSize(2);
    if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
        n != outputN)
      return emitOpError() << "requires values_out dimension 0 to have size "
                           << n << ", got " << outputN;
    if (k == ShapedType::kDynamic)
      k = outputK;
    else if (outputK != ShapedType::kDynamic && k != outputK)
      return emitOpError() << "requires values_out dimension 1 to have size "
                           << k << ", got " << outputK;
    if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
        c != outputC)
      return emitOpError() << "requires values_out dimension 2 to have size "
                           << c << ", got " << outputC;
  }
  // ERROR_IF from the spec: the scatter destination must hold at least as
  // many rows as are written.
  if (k != ShapedType::kDynamic && w != ShapedType::kDynamic && !(k >= w))
    return emitOpError() << "requires dimensions K >= W, got K=" << k
                         << " and W=" << w;

  return success();
}
3252
3253static LogicalResult ReduceInferReturnTypes(
3254 ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
3255 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3256 int64_t axisVal = axis.getValue().getSExtValue();
3257 if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
3258 inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
3259 return success();
3260 }
3261
3262 SmallVector<int64_t> outputShape;
3263 operandShape.getDims(outputShape);
3264 outputShape[axisVal] = 1;
3265 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
3266 return success();
3267}
3268
// Emits isCompatibleReturnTypes for OP: exactly one result on each side,
// matching element types, and shapes that verify as compatible.
#define COMPATIBLE_RETURN_TYPES(OP)                                            \
  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) {                 \
    if (l.size() != r.size() || l.size() != 1)                                 \
      return false;                                                            \
    if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))              \
      return false;                                                            \
    return succeeded(verifyCompatibleShape(l[0], r[0]));                       \
  }

// Emits inferReturnTypeComponents for a reduce op by forwarding its input
// shape and axis property to the shared ReduceInferReturnTypes, plus the
// matching isCompatibleReturnTypes.
#define REDUCE_SHAPE_INFER(OP)                                                 \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,                \
      OP::Adaptor adaptor,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    Type inputType =                                                           \
        llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
    ShapeAdaptor inputShape(adaptor.getInput().getType());                     \
    const Properties &prop = adaptor.getProperties();                          \
    return ReduceInferReturnTypes(inputShape, inputType, prop.axis,            \
                                  inferredReturnShapes);                       \
  }                                                                            \
  COMPATIBLE_RETURN_TYPES(OP)

REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
REDUCE_SHAPE_INFER(tosa::ReduceProductOp)
REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
#undef REDUCE_SHAPE_INFER
COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
#undef COMPATIBLE_RETURN_TYPES
3301
// Shared verifier for all TOSA reduce ops: the axis must be non-negative and
// within rank (with a rank-0/axis-0 special case), ranks must match, and the
// reduced output dimension must be 1 when static.
template <typename T>
static LogicalResult verifyReduceOp(T op) {
  // All TOSA reduce Ops have input, output and axis.
  TensorType inputType = op.getInput().getType();
  TensorType outputType = op.getOutput().getType();
  int32_t reduceAxis = op.getAxis();

  if (reduceAxis < 0) {
    op.emitOpError("reduce axis must not be negative");
    return failure();
  }
  if (inputType.hasRank()) {
    int64_t inputRank = inputType.getRank();
    // We allow for a special case where the input/output shape has rank 0 and
    // axis is also 0.
    if (reduceAxis >= inputRank && (reduceAxis != 0 || inputRank != 0)) {
      op.emitOpError("expect input tensor rank (")
          << inputRank << ") to be larger than reduce axis (" << reduceAxis
          << ")";
      return failure();
    }
  }
  if (outputType.hasRank()) {
    int64_t outputRank = outputType.getRank();
    if (inputType.hasRank() && outputRank != inputType.getRank()) {
      op.emitOpError(
          "expect output tensor rank to be equal to input tensor rank");
      return failure();
    }
    if (reduceAxis >= outputRank && (reduceAxis != 0 || outputRank != 0)) {
      op.emitOpError("expect output tensor rank (")
          << outputRank << ") to be larger than reduce axis (" << reduceAxis
          << ")";
      return failure();
    }
    // We can only verify the reduced dimension size to be 1 if this is not
    // the special case of output rank == 0.
    if (outputRank != 0) {
      auto outputShape = outputType.getShape();
      if (!outputType.isDynamicDim(reduceAxis) &&
          outputShape[reduceAxis] != 1) {
        op.emitOpError("expect reduced dimension size to be 1, got ")
            << outputShape[reduceAxis];
        return failure();
      }
    }
  }
  return success();
}
3351
// All reduce ops share the verifier above.
LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceProductOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
3358
3359static LogicalResult NAryInferReturnTypes(
3360 const ValueShapeRange &operands,
3361 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3363 if (resolveBroadcastShape(operands, outShape).failed()) {
3364 inferredReturnShapes.push_back(ShapedTypeComponents());
3365 } else {
3366 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
3367 }
3368 return success();
3369}
3370
3371#define NARY_SHAPE_INFER(OP) \
3372 LogicalResult OP::inferReturnTypeComponents( \
3373 MLIRContext *context, ::std::optional<Location> location, \
3374 ValueShapeRange operands, DictionaryAttr attributes, \
3375 OpaqueProperties properties, RegionRange regions, \
3376 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3377 return NAryInferReturnTypes(operands, inferredReturnShapes); \
3378 }
3379
3380NARY_SHAPE_INFER(tosa::AbsOp)
3381NARY_SHAPE_INFER(tosa::AddOp)
3382NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
3383NARY_SHAPE_INFER(tosa::BitwiseAndOp)
3384NARY_SHAPE_INFER(tosa::BitwiseOrOp)
3385NARY_SHAPE_INFER(tosa::BitwiseXorOp)
3386NARY_SHAPE_INFER(tosa::BitwiseNotOp)
3387NARY_SHAPE_INFER(tosa::CastOp)
3388NARY_SHAPE_INFER(tosa::CeilOp)
3389NARY_SHAPE_INFER(tosa::ClampOp)
3390NARY_SHAPE_INFER(tosa::ClzOp)
3391NARY_SHAPE_INFER(tosa::CosOp)
3392NARY_SHAPE_INFER(tosa::ExpOp)
3393NARY_SHAPE_INFER(tosa::FloorOp)
3394NARY_SHAPE_INFER(tosa::GreaterEqualOp)
3395NARY_SHAPE_INFER(tosa::GreaterOp)
3396NARY_SHAPE_INFER(tosa::IdentityOp)
3397NARY_SHAPE_INFER(tosa::IntDivOp)
3398NARY_SHAPE_INFER(tosa::LogOp)
3399NARY_SHAPE_INFER(tosa::LogicalAndOp)
3400NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
3401NARY_SHAPE_INFER(tosa::LogicalNotOp)
3402NARY_SHAPE_INFER(tosa::LogicalOrOp)
3403NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
3404NARY_SHAPE_INFER(tosa::LogicalXorOp)
3405NARY_SHAPE_INFER(tosa::MaximumOp)
3406NARY_SHAPE_INFER(tosa::MinimumOp)
3407NARY_SHAPE_INFER(tosa::PowOp)
3408NARY_SHAPE_INFER(tosa::ReciprocalOp)
3409NARY_SHAPE_INFER(tosa::ReverseOp)
3410NARY_SHAPE_INFER(tosa::RsqrtOp)
3411NARY_SHAPE_INFER(tosa::SinOp)
3412NARY_SHAPE_INFER(tosa::SelectOp)
3413NARY_SHAPE_INFER(tosa::SubOp)
3414NARY_SHAPE_INFER(tosa::TanhOp)
3415NARY_SHAPE_INFER(tosa::ErfOp)
3416NARY_SHAPE_INFER(tosa::SigmoidOp)
3417#undef PRED_SHAPE_INFER
3418
3419LogicalResult tosa::NegateOp::inferReturnTypeComponents(
3420 MLIRContext *context, ::std::optional<Location> location,
3421 NegateOp::Adaptor adaptor,
3422 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3423 ShapeAdaptor inputShape(adaptor.getInput1().getType());
3424 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3425 return success();
3426}
3427
3428LogicalResult tosa::NegateOp::verify() {
3429 // Verify same element type
3430 const Type input1Type = getInput1().getType();
3431 const Type outputType = getOutput().getType();
3432 if (verifySameElementTypes(*this, input1Type, outputType).failed())
3433 return failure();
3434
3435 // Verify same shape
3436 const SmallVector<Type, 2> types = {input1Type, outputType};
3437 if (failed(verifyCompatibleShapes(types)))
3438 return emitOpError() << "requires the same shape for input1 and output";
3439
3440 const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
3441 const Type input1ZpEType =
3442 getStorageElementTypeOrSelf(getInput1Zp().getType());
3443 if (input1EType != input1ZpEType) {
3444 return emitOpError("expect both input1 and its zero point are the same "
3445 "element type, got ")
3446 << input1EType << " and " << input1ZpEType;
3447 }
3448 const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
3449 const Type outputZpEType =
3450 getStorageElementTypeOrSelf(getOutputZp().getType());
3451 if (outputEType != outputZpEType) {
3452 return emitOpError("expect both output and its zero point are the same "
3453 "element type, got ")
3454 << outputEType << " and " << outputZpEType;
3455 }
3456
3457 FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
3458 if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
3459 return failure();
3460
3461 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3462 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
3463 return failure();
3464
3465 return success();
3466}
3467
3468static LogicalResult poolingInferReturnTypes(
3469 ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
3471 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3472 llvm::SmallVector<int64_t> outputShape;
3473 outputShape.resize(4, ShapedType::kDynamic);
3474
3475 // We only know the rank if the input type is unranked.
3476 if (!inputShape.hasRank()) {
3477 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3478 return success();
3479 }
3480
3481 // Batch and number of channels are identical for pooling layer.
3482 outputShape[0] = inputShape.getDimSize(0);
3483 outputShape[3] = inputShape.getDimSize(3);
3484
3485 int64_t height = inputShape.getDimSize(1);
3486 int64_t width = inputShape.getDimSize(2);
3487
3488 if (ShapedType::isStatic(height)) {
3489 int64_t padded = height + pad[0] + pad[1] - kernel[0];
3490 outputShape[1] = padded / stride[0] + 1;
3491 }
3492
3493 if (ShapedType::isStatic(width)) {
3494 int64_t padded = width + pad[2] + pad[3] - kernel[1];
3495 outputShape[2] = padded / stride[1] + 1;
3496 }
3497
3498 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3499 return success();
3500}
3501
3502template <typename AdaptorT>
3504
3506protected:
3507 static void updateIfDynamic(int64_t &current, int64_t candidate) {
3508 if (ShapedType::isDynamic(current))
3509 current = candidate;
3510 }
3511};
3512
3513template <>
3514class ConvInferShapeAdaptor<Conv2DOp::Adaptor>
3515 : public ConvInferShapeAdaptorBase {
3516public:
3517 explicit ConvInferShapeAdaptor(Conv2DOp::Adaptor adaptor)
3518 : adaptor(adaptor) {}
3519
3521 SmallVectorImpl<int64_t> &inputSpatial) {
3522 const ShapeAdaptor inputShape(adaptor.getInput().getType());
3523 if (!inputShape.hasRank())
3524 return;
3525
3526 const int64_t outputBatch = inputShape.getDimSize(0);
3527 const int64_t inputHeight = inputShape.getDimSize(1);
3528 const int64_t inputWidth = inputShape.getDimSize(2);
3529
3530 outputShape[0] = outputBatch;
3531 inputSpatial[0] = inputHeight;
3532 inputSpatial[1] = inputWidth;
3533 }
3534
3536 SmallVectorImpl<int64_t> &weightSpatial) {
3537 const ShapeAdaptor weightShape(adaptor.getWeight().getType());
3538 if (!weightShape.hasRank())
3539 return;
3540
3541 const int64_t outputChannels = weightShape.getDimSize(0);
3542 const int64_t kernelHeight = weightShape.getDimSize(1);
3543 const int64_t kernelWidth = weightShape.getDimSize(2);
3544
3545 outputShape[3] = outputChannels;
3546 weightSpatial[0] = kernelHeight;
3547 weightSpatial[1] = kernelWidth;
3548 }
3549
3550 int64_t getNumSpatialDims() const { return 2; }
3551 int64_t getOutputRank() const { return 4; }
3552
3554 SmallVector<int64_t> &strideValues,
3555 SmallVector<int64_t> &dilationValues) {
3556 padValues.assign(adaptor.getPad().begin(), adaptor.getPad().end());
3557 strideValues.assign(adaptor.getStride().begin(), adaptor.getStride().end());
3558 dilationValues.assign(adaptor.getDilation().begin(),
3559 adaptor.getDilation().end());
3560 return success();
3561 }
3562
3563private:
3564 Conv2DOp::Adaptor adaptor;
3565};
3566
3567template <>
3568class ConvInferShapeAdaptor<Conv2DBlockScaledOp::Adaptor>
3569 : public ConvInferShapeAdaptorBase {
3570public:
3571 explicit ConvInferShapeAdaptor(Conv2DBlockScaledOp::Adaptor adaptor)
3572 : adaptor(adaptor) {}
3573
3575 SmallVectorImpl<int64_t> &inputSpatial) {
3576 const ShapeAdaptor inputDataShape(adaptor.getInputData().getType());
3577 if (inputDataShape.hasRank()) {
3578 const int64_t outputBatch = inputDataShape.getDimSize(0);
3579 const int64_t inputHeight = inputDataShape.getDimSize(1);
3580 const int64_t inputWidth = inputDataShape.getDimSize(2);
3581
3582 outputShape[0] = outputBatch;
3583 inputSpatial[0] = inputHeight;
3584 inputSpatial[1] = inputWidth;
3585 }
3586
3587 const ShapeAdaptor inputScaleShape(adaptor.getInputScale().getType());
3588 if (!inputScaleShape.hasRank())
3589 return;
3590
3591 const int64_t scaleBatch = inputScaleShape.getDimSize(0);
3592 const int64_t scaleHeight = inputScaleShape.getDimSize(1);
3593 const int64_t scaleWidth = inputScaleShape.getDimSize(2);
3594
3595 updateIfDynamic(outputShape[0], scaleBatch);
3596 updateIfDynamic(inputSpatial[0], scaleHeight);
3597 updateIfDynamic(inputSpatial[1], scaleWidth);
3598 }
3599
3601 SmallVectorImpl<int64_t> &weightSpatial) {
3602 const ShapeAdaptor weightDataShape(adaptor.getWeightData().getType());
3603 if (weightDataShape.hasRank()) {
3604 const int64_t outputChannels = weightDataShape.getDimSize(0);
3605 const int64_t kernelHeight = weightDataShape.getDimSize(1);
3606 const int64_t kernelWidth = weightDataShape.getDimSize(2);
3607
3608 outputShape[3] = outputChannels;
3609 weightSpatial[0] = kernelHeight;
3610 weightSpatial[1] = kernelWidth;
3611 }
3612
3613 const ShapeAdaptor weightScaleShape(adaptor.getWeightScale().getType());
3614 if (!weightScaleShape.hasRank())
3615 return;
3616
3617 const int64_t scaleOutputChannels = weightScaleShape.getDimSize(0);
3618 const int64_t scaleKernelHeight = weightScaleShape.getDimSize(1);
3619 const int64_t scaleKernelWidth = weightScaleShape.getDimSize(2);
3620
3621 updateIfDynamic(outputShape[3], scaleOutputChannels);
3622 updateIfDynamic(weightSpatial[0], scaleKernelHeight);
3623 updateIfDynamic(weightSpatial[1], scaleKernelWidth);
3624 }
3625
3626 int64_t getNumSpatialDims() const { return 2; }
3627 int64_t getOutputRank() const { return 4; }
3628
3630 SmallVector<int64_t> &strideValues,
3631 SmallVector<int64_t> &dilationValues) {
3632 if (!tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(),
3633 padValues) ||
3634 !tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(),
3635 strideValues) ||
3636 !tosa::getConstShapeValues(adaptor.getDilation().getDefiningOp(),
3637 dilationValues))
3638 return failure();
3639 return success();
3640 }
3641
3642private:
3643 Conv2DBlockScaledOp::Adaptor adaptor;
3644};
3645
3646template <>
3647class ConvInferShapeAdaptor<Conv3DOp::Adaptor>
3648 : public ConvInferShapeAdaptorBase {
3649public:
3650 explicit ConvInferShapeAdaptor(Conv3DOp::Adaptor adaptor)
3651 : adaptor(adaptor) {}
3652
3654 SmallVectorImpl<int64_t> &inputSpatial) {
3655 const ShapeAdaptor inputShape(adaptor.getInput().getType());
3656 if (!inputShape.hasRank())
3657 return;
3658
3659 const int64_t outputBatch = inputShape.getDimSize(0);
3660 const int64_t inputDepth = inputShape.getDimSize(1);
3661 const int64_t inputHeight = inputShape.getDimSize(2);
3662 const int64_t inputWidth = inputShape.getDimSize(3);
3663
3664 outputShape[0] = outputBatch;
3665 inputSpatial[0] = inputDepth;
3666 inputSpatial[1] = inputHeight;
3667 inputSpatial[2] = inputWidth;
3668 }
3669
3671 SmallVectorImpl<int64_t> &weightSpatial) {
3672 const ShapeAdaptor weightShape(adaptor.getWeight().getType());
3673 if (!weightShape.hasRank())
3674 return;
3675
3676 const int64_t outputChannels = weightShape.getDimSize(0);
3677 const int64_t kernelDepth = weightShape.getDimSize(1);
3678 const int64_t kernelHeight = weightShape.getDimSize(2);
3679 const int64_t kernelWidth = weightShape.getDimSize(3);
3680
3681 outputShape[4] = outputChannels;
3682 weightSpatial[0] = kernelDepth;
3683 weightSpatial[1] = kernelHeight;
3684 weightSpatial[2] = kernelWidth;
3685 }
3686
3687 int64_t getNumSpatialDims() const { return 3; }
3688 int64_t getOutputRank() const { return 5; }
3689
3691 SmallVector<int64_t> &strideValues,
3692 SmallVector<int64_t> &dilationValues) {
3693 padValues.assign(adaptor.getPad().begin(), adaptor.getPad().end());
3694 strideValues.assign(adaptor.getStride().begin(), adaptor.getStride().end());
3695 dilationValues.assign(adaptor.getDilation().begin(),
3696 adaptor.getDilation().end());
3697 return success();
3698 }
3699
3700private:
3701 Conv3DOp::Adaptor adaptor;
3702};
3703
3704template <typename AdaptorT>
3706 AdaptorT adaptor,
3707 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3708 ConvInferShapeAdaptor<AdaptorT> convShapeAdaptor(adaptor);
3709 llvm::SmallVector<int64_t> outputShape(convShapeAdaptor.getOutputRank(),
3710 ShapedType::kDynamic);
3711 llvm::SmallVector<int64_t> inputSpatial(convShapeAdaptor.getNumSpatialDims(),
3712 ShapedType::kDynamic);
3713 llvm::SmallVector<int64_t> weightSpatial(convShapeAdaptor.getNumSpatialDims(),
3714 ShapedType::kDynamic);
3715
3716 convShapeAdaptor.inferInputShape(outputShape, inputSpatial);
3717 convShapeAdaptor.inferWeightShape(outputShape, weightSpatial);
3718
3719 const ShapeAdaptor biasShape = adaptor.getBias().getType();
3720 if (biasShape.hasRank()) {
3721 const int64_t biasSize = biasShape.getDimSize(0);
3722 if (biasSize != 1) {
3723 const size_t outputChannelDim = convShapeAdaptor.getOutputRank() - 1;
3724 outputShape[outputChannelDim] =
3725 ShapedType::isDynamic(outputShape[outputChannelDim])
3726 ? biasSize
3727 : outputShape[outputChannelDim];
3728 }
3729 }
3730
3731 SmallVector<int64_t> padValues;
3732 SmallVector<int64_t> strideValues;
3733 SmallVector<int64_t> dilationValues;
3734 if (failed(convShapeAdaptor.getSpatialParameters(padValues, strideValues,
3735 dilationValues))) {
3736 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3737 return success();
3738 }
3739
3740 for (int64_t dim = 0; dim < convShapeAdaptor.getNumSpatialDims(); ++dim) {
3741 if (!ShapedType::isStatic(inputSpatial[dim]) ||
3742 !ShapedType::isStatic(weightSpatial[dim]))
3743 continue;
3744 const int64_t inputSize =
3745 inputSpatial[dim] + padValues[2 * dim] + padValues[2 * dim + 1];
3746 const int64_t filterSize =
3747 (weightSpatial[dim] - 1) * dilationValues[dim] + 1;
3748 const int64_t unstridedResult = inputSize - filterSize + 1;
3749 outputShape[dim + 1] = (unstridedResult - 1) / strideValues[dim] + 1;
3750 }
3751
3752 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3753 return success();
3754}
3755
// Shape inference for tosa.conv2d: forwards to the shared convolution driver
// through the Conv2DOp shape adaptor.
LogicalResult Conv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  return inferConvReturnTypeComponents(adaptor, inferredReturnShapes);
}
3762
3763LogicalResult Conv2DOp::verify() {
3764 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3765 verifyConvOpErrorIf(*this).failed())
3766 return failure();
3767 return success();
3768}
3769
// Shape inference for tosa.conv2d_block_scaled: forwards to the shared
// convolution driver through the Conv2DBlockScaledOp shape adaptor.
LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv2DBlockScaledOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  return inferConvReturnTypeComponents(adaptor, inferredReturnShapes);
}
3776
3777LogicalResult Conv2DBlockScaledOp::verify() {
3778 if (failed(verifySameElementTypes(*this, getInputData().getType(),
3779 getWeightData().getType(), "input_data",
3780 "weight_data")) ||
3781 failed(verifySameElementTypes(*this, getInputScale().getType(),
3782 getWeightScale().getType(), "input_scale",
3783 "weight_scale")) ||
3784 failed(verifySameElementTypes(*this, getBias().getType(),
3785 getOutput().getType(), "bias", "output")))
3786 return failure();
3787
3788 // Verify input shape compatibility
3789 int64_t N = ShapedType::kDynamic;
3790 int64_t IH = ShapedType::kDynamic;
3791 int64_t IW = ShapedType::kDynamic;
3792 int64_t IC = ShapedType::kDynamic;
3793 int64_t multiplesOfIC = ShapedType::kDynamic;
3794 int64_t OC = ShapedType::kDynamic;
3795 int64_t KH = ShapedType::kDynamic;
3796 int64_t KW = ShapedType::kDynamic;
3797
3798 const ShapeAdaptor inputDataShape(getInputData().getType());
3799 if (inputDataShape.hasRank()) {
3800 N = inputDataShape.getDimSize(0);
3801 IH = inputDataShape.getDimSize(1);
3802 IW = inputDataShape.getDimSize(2);
3803 IC = inputDataShape.getDimSize(3);
3804 }
3805
3806 const ShapeAdaptor inputScaleShape(getInputScale().getType());
3807 if (inputScaleShape.hasRank()) {
3808 if (failed(tryUpdateDimOrFailure(*this, N, inputScaleShape.getDimSize(0),
3809 "input_scale", "batch size")) ||
3810 failed(tryUpdateDimOrFailure(*this, IH, inputScaleShape.getDimSize(1),
3811 "input_scale", "input height")) ||
3812 failed(tryUpdateDimOrFailure(*this, IW, inputScaleShape.getDimSize(2),
3813 "input_scale", "input width")))
3814 return failure();
3815 multiplesOfIC = inputScaleShape.getDimSize(3);
3816 }
3817
3818 const ShapeAdaptor weightDataShape(getWeightData().getType());
3819 if (weightDataShape.hasRank()) {
3820 OC = weightDataShape.getDimSize(0);
3821 KH = weightDataShape.getDimSize(1);
3822 KW = weightDataShape.getDimSize(2);
3823 if (failed(tryUpdateDimOrFailure(*this, IC, weightDataShape.getDimSize(3),
3824 "weight_data", "input channels")))
3825 return failure();
3826 }
3827
3828 const ShapeAdaptor weightScaleShape(getWeightScale().getType());
3829 if (weightScaleShape.hasRank()) {
3830 if (failed(tryUpdateDimOrFailure(*this, OC, weightScaleShape.getDimSize(0),
3831 "weight_scale", "output channels")) ||
3832 failed(tryUpdateDimOrFailure(*this, KH, weightScaleShape.getDimSize(1),
3833 "weight_scale", "kernel height")) ||
3834 failed(tryUpdateDimOrFailure(*this, KW, weightScaleShape.getDimSize(2),
3835 "weight_scale", "kernel width")) ||
3836 failed(tryUpdateDimOrFailure(*this, multiplesOfIC,
3837 weightScaleShape.getDimSize(3),
3838 "weight_scale", "input channel blocks")))
3839 return failure();
3840 }
3841
3842 // Verify IC is a multiple of block size
3843 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
3844 if (ShapedType::isStatic(IC) && IC % blockSize != 0)
3845 return emitOpError("expect IC to be a multiple of block size, got IC=")
3846 << IC << ", block_size=" << blockSize;
3847
3848 // Verify multiplesOfIC is IC / block size
3849 if (ShapedType::isStatic(IC) && ShapedType::isStatic(multiplesOfIC) &&
3850 multiplesOfIC != IC / blockSize)
3851 return emitOpError(
3852 "expect scale operands dimension 2 to equal IC/block_size (")
3853 << IC << "/" << blockSize << ")"
3854 << ", got " << multiplesOfIC;
3855
3856 // Verify pad/stride/dilation values
3857 SmallVector<int64_t> padValues;
3858 if (tosa::getConstShapeValues(getPad().getDefiningOp(), padValues)) {
3859 if (llvm::any_of(padValues, [](int64_t p) { return p < 0; }))
3860 return emitOpError("expect all padding values to be >= 0, got ")
3861 << padValues;
3862 }
3863
3864 SmallVector<int64_t> strideValues;
3865 if (tosa::getConstShapeValues(getStride().getDefiningOp(), strideValues)) {
3866 if (llvm::any_of(strideValues, [](int64_t s) { return s < 1; }))
3867 return emitOpError("expect all stride values to be >= 1, got ")
3868 << strideValues;
3869 }
3870
3871 SmallVector<int64_t> dilationValues;
3872 if (tosa::getConstShapeValues(getDilation().getDefiningOp(),
3873 dilationValues)) {
3874 if (llvm::any_of(dilationValues, [](int64_t d) { return d < 1; }))
3875 return emitOpError("expect all dilation values to be >= 1, got ")
3876 << dilationValues;
3877 }
3878
3879 // Verify output shape compatibility
3880 const ShapeAdaptor outputShape(getOutput().getType());
3881 if (!padValues.empty() && !strideValues.empty() && !dilationValues.empty() &&
3882 outputShape.hasRank()) {
3883 if (failed(verifyConvOutputSize(*this, IH, KH, outputShape.getDimSize(1),
3884 padValues[0], padValues[1], strideValues[0],
3885 dilationValues[0], "height", "y", "top",
3886 "bottom")) ||
3887 failed(verifyConvOutputSize(*this, IW, KW, outputShape.getDimSize(2),
3888 padValues[2], padValues[3], strideValues[1],
3889 dilationValues[1], "width", "x", "left",
3890 "right")))
3891 return failure();
3892 }
3893
3894 // Verify bias
3895 const ShapeAdaptor biasShape(getBias().getType());
3896 if (biasShape.hasRank() && outputShape.hasRank()) {
3897 const int64_t biasChannels = biasShape.getDimSize(0);
3898 const int64_t outputChannels =
3899 outputShape.getDimSize(outputShape.getRank() - 1);
3900 if (biasChannels == ShapedType::kDynamic ||
3901 outputChannels == ShapedType::kDynamic)
3902 // Skip following checks if biasChannels or outputChannels is dynamic dim
3903 return success();
3904
3905 if (biasChannels != outputChannels && biasChannels != 1)
3906 return emitOpError(
3907 "bias channels expected to be equal to output channels (")
3908 << outputChannels << ") or 1, got " << biasChannels;
3909 }
3910
3911 return success();
3912}
3913
// Shape inference for tosa.conv3d: forwards to the shared convolution driver
// through the Conv3DOp shape adaptor.
LogicalResult Conv3DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv3DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  return inferConvReturnTypeComponents(adaptor, inferredReturnShapes);
}
3920
3921LogicalResult Conv3DOp::verify() {
3922 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3923 verifyConvOpErrorIf(*this).failed())
3924 return failure();
3925 return success();
3926}
3927
3928LogicalResult AvgPool2dOp::inferReturnTypeComponents(
3929 MLIRContext *context, ::std::optional<Location> location,
3930 AvgPool2dOp::Adaptor adaptor,
3931 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3932 ShapeAdaptor inputShape(adaptor.getInput().getType());
3933 const Properties &prop = adaptor.getProperties();
3934 return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3935 inferredReturnShapes);
3936}
3937
3938LogicalResult MaxPool2dOp::inferReturnTypeComponents(
3939 MLIRContext *context, ::std::optional<Location> location,
3940 MaxPool2dOp::Adaptor adaptor,
3941 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3942 ShapeAdaptor inputShape(adaptor.getInput().getType());
3943 const Properties &prop = adaptor.getProperties();
3944 return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3945 inferredReturnShapes);
3946}
3947
3948LogicalResult MaxPool2dOp::verify() {
3949 if (failed(verifySameElementTypes(*this, /* intype = */ getInput().getType(),
3950 /* outType = */ getOutput().getType())))
3951 return failure();
3952
3953 if (failed(verifyPoolingOp(*this)))
3954 return failure();
3955
3956 return success();
3957}
3958
// Shape inference for tosa.depthwise_conv2d (NHWC input; weight dims read as
// [KH, KW, IC, channel_multiplier] below). Output channels are
// IC * channel_multiplier; each spatial dim follows the usual convolution
// formula when its extents are static.
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    DepthwiseConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputChannels = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t depthChannels = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
    inputChannels = inputShape.getDimSize(3);
  }

  // Weight shapes describes the filter width/height and the output channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    weightHeight = weightShape.getDimSize(0);
    weightWidth = weightShape.getDimSize(1);
    // The weight can supply IC when the input did not.
    inputChannels = ShapedType::isDynamic(inputChannels)
                        ? weightShape.getDimSize(2)
                        : inputChannels;
    depthChannels = weightShape.getDimSize(3);
  }

  // If both inputChannels and depthChannels are available we can determine
  // the output channels.
  if (ShapedType::isStatic(inputChannels) &&
      ShapedType::isStatic(depthChannels)) {
    outputShape[3] = inputChannels * depthChannels;
  }

  // Bias shape can describe the output channels. A size-1 bias broadcasts,
  // so it constrains nothing.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
    int64_t bc = biasShape.getDimSize(0);
    if (bc != ShapedType::kDynamic && bc != 1)
      outputShape[3] = bc;
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  // OH = (IH + pad_top + pad_bottom - dilated_KH) / stride_y + 1.
  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
    int64_t inputSize = inputHeight + padding[0] + padding[1];
    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  // OW = (IW + pad_left + pad_right - dilated_KW) / stride_x + 1.
  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
    int64_t inputSize = inputWidth + padding[2] + padding[3];
    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
4029
4030LogicalResult DepthwiseConv2DOp::verify() {
4031 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
4032 verifyConvOpErrorIf(*this).failed())
4033 return failure();
4034 return success();
4035}
4036
4037LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
4038 MLIRContext *context, ::std::optional<Location> location,
4039 TransposeConv2DOp::Adaptor adaptor,
4040 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4041 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
4042
4043 int64_t inputWidth = ShapedType::kDynamic;
4044 int64_t inputHeight = ShapedType::kDynamic;
4045 int64_t weightWidth = ShapedType::kDynamic;
4046 int64_t weightHeight = ShapedType::kDynamic;
4047
4048 // Input shape describes input width/height and batch.
4049 ShapeAdaptor inputShape(adaptor.getInput().getType());
4050 if (inputShape.hasRank()) {
4051 outputShape[0] = ShapedType::isDynamic(outputShape[0])
4052 ? inputShape.getDimSize(0)
4053 : outputShape[0];
4054 inputHeight = inputShape.getDimSize(1);
4055 inputWidth = inputShape.getDimSize(2);
4056 }
4057
4058 // Weight shapes describes the filter width/height and the output channels.
4059 ShapeAdaptor weightShape(adaptor.getWeight().getType());
4060 if (weightShape.hasRank()) {
4061 outputShape[3] = ShapedType::isDynamic(outputShape[3])
4062 ? weightShape.getDimSize(0)
4063 : outputShape[3];
4064 weightHeight = weightShape.getDimSize(1);
4065 weightWidth = weightShape.getDimSize(2);
4066 }
4067
4068 // Bias shape can describe the output channels.
4069 ShapeAdaptor biasShape(adaptor.getBias().getType());
4070 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
4071 int64_t bc = biasShape.getDimSize(0);
4072 if (bc != ShapedType::kDynamic && bc != 1)
4073 outputShape[3] = bc;
4074 }
4075
4076 llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
4077 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
4078
4079 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
4080 int64_t calculateSize =
4081 (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
4082 outputShape[1] =
4083 ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
4084 }
4085
4086 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
4087 int64_t calculateSize =
4088 (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
4089 outputShape[2] =
4090 ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
4091 }
4092
4093 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4094 return success();
4095}
4096
// Verifier for tosa.transpose_conv2d: shared conv checks, stride positivity,
// the out_pad > -kernel constraints, the exact output-size equations for any
// static spatial dims, and finally the bias channel count.
LogicalResult TransposeConv2DOp::verify() {
  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
    return failure();

  const llvm::ArrayRef<int64_t> strides = getStride();
  const int64_t strideY = strides[0];
  const int64_t strideX = strides[1];

  if (strideY < 1 || strideX < 1)
    return emitOpError("expect all stride values to be >= 1, got [")
           << strides << "]";

  // Each out_pad entry must satisfy pad > -kernel_dim, otherwise the output
  // extent formula would go non-positive.
  const auto checkPadAgainstKernelDim =
      [this](int64_t padValue, int64_t kernelDimSize, llvm::StringRef padName,
             llvm::StringRef kernelDimName) -> LogicalResult {
    if (padValue <= -kernelDimSize)
      return emitOpError("expected ")
             << padName << " > -" << kernelDimName << ", but got: " << padName
             << "=" << padValue << " and " << kernelDimName << "="
             << kernelDimSize;
    return success();
  };

  // out_pad layout: [top, bottom, left, right].
  const llvm::ArrayRef<int64_t> padding = getOutPad();
  const int64_t outPadTop = padding[0];
  const int64_t outPadBottom = padding[1];
  const int64_t outPadLeft = padding[2];
  const int64_t outPadRight = padding[3];

  const auto weightType =
      llvm::dyn_cast<RankedTensorType>(getWeight().getType());

  if (weightType) {
    const int64_t kernelHeight = weightType.getDimSize(1);
    if (ShapedType::isStatic(kernelHeight)) {
      if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
                                          "out_pad_top", "KH")))
        return failure();

      if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
                                          "out_pad_bottom", "KH")))
        return failure();
    }

    const int64_t kernelWidth = weightType.getDimSize(2);
    if (ShapedType::isStatic(kernelWidth)) {
      if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
                                          "out_pad_left", "KW")))
        return failure();

      if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
                                          "out_pad_right", "KW")))
        return failure();
    }
  }

  // Rest of the checks depend on the output type being a RankedTensorType
  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
  if (!outputType)
    return success();

  // With input, weight, and output all ranked, the output spatial dims must
  // satisfy the transpose-conv size equations exactly.
  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  if (inputType && weightType) {
    const int64_t inputHeight = inputType.getDimSize(1);
    const int64_t kernelHeight = weightType.getDimSize(1);
    const int64_t outputHeight = outputType.getDimSize(1);

    if (ShapedType::isStatic(inputHeight) &&
        ShapedType::isStatic(outputHeight)) {
      if (outputHeight !=
          (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
        return emitOpError(
                   "dimension mismatch: expected OH == (IH - 1) * stride_y "
                   "+ out_pad_top + out_pad_bottom + KH, but got ")
               << outputHeight << " != (" << inputHeight << " - 1) * "
               << strideY << " + " << outPadTop << " + " << outPadBottom
               << " + " << kernelHeight;
    }

    const int64_t inputWidth = inputType.getDimSize(2);
    const int64_t kernelWidth = weightType.getDimSize(2);
    const int64_t outputWidth = outputType.getDimSize(2);

    if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
      if (outputWidth !=
          (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
        return emitOpError(
                   "dimension mismatch: expected OW == (IW - 1) * stride_x "
                   "+ out_pad_left + out_pad_right + KW, but got ")
               << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
               << " + " << outPadLeft << " + " << outPadRight << " + "
               << kernelWidth;
    }
  }

  const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());

  if (!biasType)
    return success();

  const int64_t biasChannels = biasType.getDimSize(0);

  // Skip further checks if bias is dynamic
  if (biasChannels == ShapedType::kDynamic)
    return success();

  // Bias must broadcast (size 1) or match the output channel count.
  const int64_t outputChannels = outputType.getDimSize(3);
  if (!ShapedType::isDynamic(outputChannels) &&
      biasChannels != outputChannels && biasChannels != 1)
    return emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;

  return success();
}
4213
4214LogicalResult RescaleOp::verify() {
4215 auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
4216 if (!inputType) {
4217 emitOpError("expect shaped tensor for input, got ") << getInput().getType();
4218 return failure();
4219 }
4220
4221 auto inputElementType =
4222 getStorageElementTypeOrSelf(inputType.getElementType());
4223 if (!mlir::isa<IntegerType>(inputElementType)) {
4224 emitOpError("expect input to have integer element type, got ")
4225 << inputElementType;
4226 return failure();
4227 }
4228
4229 auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
4230 if (!outputType) {
4231 emitOpError("expect shaped tensor for output, got ")
4232 << getOutput().getType();
4233 return failure();
4234 }
4235
4236 auto outputElementType =
4237 getStorageElementTypeOrSelf(outputType.getElementType());
4238 if (!mlir::isa<IntegerType>(outputElementType)) {
4239 emitOpError("expect output to have integer element type, got ")
4240 << outputElementType;
4241 return failure();
4242 }
4243
4244 if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
4245 .failed())
4246 return failure();
4247
4248 if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
4249 .failed())
4250 return failure();
4251
4252 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
4253 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
4254 return failure();
4255
4256 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
4257 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
4258 return failure();
4259
4260 auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
4261 if (!multiplierType) {
4262 emitOpError("expect shaped tensor for multiplier, got ")
4263 << getMultiplier().getType();
4264 return failure();
4265 }
4266
4267 auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
4268 if (!shiftType) {
4269 emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
4270 return failure();
4271 }
4272
4273 // multiplier element type must be i32 for scale32 = true
4274 if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
4275 emitOpError("expect i32 element type for multiplier for scale32=true, got ")
4276 << multiplierType.getElementType();
4277 return failure();
4278 }
4279
4280 // multiplier element type must be i16 for scale32 = false
4281 if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
4283 "expect i16 element type for multiplier for scale32=false, got ")
4284 << multiplierType.getElementType();
4285 return failure();
4286 }
4287
4288 if (!inputType.hasRank())
4289 return success();
4290
4291 // multiplier/shift must have shape = {numChannels},
4292 // where numChannel is 1 if per_channel = false
4293 // otherwise numChannel is dimension in input shape's last axis
4294 int64_t numChannels = 1;
4295 if (getPerChannel()) {
4296 if (inputType.getRank() < 1) {
4297 emitOpError("requires input to be at least rank 1 when per_channel is "
4298 "true, but got rank ")
4299 << inputType.getRank();
4300 return failure();
4301 }
4302 numChannels = inputType.getDimSize(inputType.getRank() - 1);
4303 }
4304
4305 if (!multiplierType.hasRank())
4306 return success();
4307
4308 ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
4309 // multiplier input has rank 1 by dialect definition
4310 if (multiplierShape[0] != ShapedType::kDynamic &&
4311 multiplierShape[0] != numChannels) {
4312 emitOpError("expect shape of { ")
4313 << numChannels << " } for multiplier input, got { "
4314 << multiplierShape[0] << " }";
4315 return failure();
4316 }
4317
4318 if (!shiftType.hasRank())
4319 return success();
4320
4321 ArrayRef<int64_t> shiftShape = shiftType.getShape();
4322 // shift input has rank 1 by dialect definition
4323 if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
4324 emitOpError("expect shape of { ")
4325 << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
4326 return failure();
4327 }
4328
4329 return success();
4330}
4331
4332LogicalResult RescaleOp::inferReturnTypeComponents(
4333 MLIRContext *context, ::std::optional<Location> location,
4334 RescaleOp::Adaptor adaptor,
4335 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4336 ShapeAdaptor inputShape(adaptor.getInput().getType());
4337 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4338 return success();
4339}
4340
4341LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
4342 MLIRContext *context, ::std::optional<Location> location,
4343 CastFromBlockScaledOp::Adaptor adaptor,
4344 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4345 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
4346 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4347 return success();
4348}
4349
4350LogicalResult CastFromBlockScaledOp::verify() {
4351 const Type inputDataType = getInputData().getType();
4352 const Type outputDataType = getResult().getType();
4353 if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
4354 return emitOpError() << "require compatible shapes for input_data ("
4355 << inputDataType << ") and " << "output_data ("
4356 << outputDataType << ")";
4357
4358 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
4359
4360 if (inputDataShape.hasRank()) {
4361 const unsigned int blockSize =
4362 BlockSizeAttr::getBlockSizeValue(getBlockSize());
4363 const int64_t inputDataLastDim =
4364 inputDataShape.getDimSize(inputDataShape.getRank() - 1);
4365 if (inputDataLastDim % blockSize != 0)
4366 return emitOpError() << "expect last dimension of input_data ("
4367 << inputDataLastDim
4368 << ") to be divisible by block_size (" << blockSize
4369 << ")";
4370
4371 const Type inputScaleType = getInputScale().getType();
4372 const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);
4373
4374 if (inputScaleShape.hasRank()) {
4375 SmallVector<int64_t> inputDataDims, inputScaleDims;
4376 inputDataShape.getDims(inputDataDims);
4377 inputScaleShape.getDims(inputScaleDims);
4378
4379 if (inputDataDims.size() != inputScaleDims.size() ||
4381 ArrayRef<int64_t>(inputDataDims).drop_back(1),
4382 ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
4383 return emitOpError()
4384 << "require compatible shapes for input_data (" << inputDataType
4385 << ") and " << "input_scale (" << inputScaleType
4386 << ") except for the last dimension";
4387
4388 const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
4389 inputScaleDims.back()};
4390 if (ShapedType::isStatic(inputDataLastDim) &&
4391 failed(verifyCompatibleDims(dimsToCheck)))
4392 return emitOpError()
4393 << "expect last dimension of input_scale ("
4394 << inputScaleDims.back()
4395 << ") to be equal to last dimension of input_data / block_size ("
4396 << inputDataDims.back() / blockSize << ")";
4397 }
4398 }
4399
4400 return success();
4401}
4402
4403LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
4404 MLIRContext *context, ::std::optional<Location> location,
4405 CastToBlockScaledOp::Adaptor adaptor,
4406 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4407 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
4408 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4409 if (!inputShape.hasRank())
4410 return success();
4411
4412 // Calculate output_scale shape if ranked input provided
4413 SmallVector<int64_t> outputScaleShape;
4414 inputShape.getDims(outputScaleShape);
4415 const int64_t lastDimLoc = inputShape.getRank() - 1;
4416 const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
4417 if (ShapedType::isStatic(lastDimSize)) {
4418 const unsigned int blockSize =
4419 BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
4420 outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
4421 }
4422 inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
4423 return success();
4424}
4425
4426LogicalResult CastToBlockScaledOp::verify() {
4427 const Type inputDataType = getInputData().getType();
4428 const Type outputDataType = getResult(0).getType();
4429 if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
4430 return emitOpError() << "require compatible shapes for input_data ("
4431 << inputDataType << ") and " << "output_data ("
4432 << outputDataType << ")";
4433
4434 const unsigned int blockSize =
4435 BlockSizeAttr::getBlockSizeValue(getBlockSize());
4436 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
4437 if (inputDataShape.hasRank()) {
4438 const int64_t inputDataLastDim =
4439 inputDataShape.getDimSize(inputDataShape.getRank() - 1);
4440 if (ShapedType::isStatic(inputDataLastDim) &&
4441 inputDataLastDim % blockSize != 0)
4442 return emitOpError() << "expect last dimension of input_data ("
4443 << inputDataLastDim
4444 << ") to be divisible by block_size (" << blockSize
4445 << ")";
4446 }
4447
4448 const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
4449 const Type outputScaleType = getResult(1).getType();
4450 const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
4451 if (outputDataShape.hasRank() && outputScaleShape.hasRank()) {
4452 SmallVector<int64_t> outputDataDims, outputScaleDims;
4453 outputDataShape.getDims(outputDataDims);
4454 outputScaleShape.getDims(outputScaleDims);
4455
4456 if (outputDataDims.size() != outputScaleDims.size() ||
4458 ArrayRef<int64_t>(outputDataDims).drop_back(1),
4459 ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
4460 return emitOpError() << "require compatible shapes for output_data ("
4461 << outputDataType << ") and " << "output_scale ("
4462 << outputScaleType
4463 << ") except for the last dimension";
4464
4465 const int64_t outputDataLastDim = outputDataDims.back();
4466 const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
4467 outputScaleDims.back()};
4468 if (ShapedType::isStatic(outputDataLastDim) &&
4469 failed(verifyCompatibleDims(dimsToCheck)))
4470 return emitOpError()
4471 << "expect last dimension of output_scale ("
4472 << outputScaleDims.back()
4473 << ") to be equal to last dimension of output_data / block_size ("
4474 << outputDataDims.back() / blockSize << ")";
4475 }
4476
4477 return success();
4478}
4479
4480LogicalResult IfOp::inferReturnTypeComponents(
4481 MLIRContext *context, ::std::optional<Location> location,
4482 IfOp::Adaptor adaptor,
4483 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4484 llvm::SmallVector<tosa::YieldOp> yieldOps;
4485 for (Region *region : adaptor.getRegions()) {
4486 for (auto &block : *region)
4487 if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
4488 yieldOps.push_back(returnOp);
4489 }
4490
4491 if (yieldOps.empty())
4492 return failure();
4493
4494 // Get the initial type information for the yield op.
4495 llvm::SmallVector<ValueKnowledge> resultKnowledge;
4496 resultKnowledge.reserve(yieldOps.front().getNumOperands());
4497 for (auto operand : yieldOps.front().getOperands()) {
4498 resultKnowledge.push_back(
4499 ValueKnowledge::getKnowledgeFromType(operand.getType()));
4500 }
4501
4502 for (auto yieldOp : yieldOps) {
4503 if (resultKnowledge.size() != yieldOp.getNumOperands())
4504 return failure();
4505
4506 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
4507 int32_t index = it.index();
4508 auto meet = ValueKnowledge::meet(
4509 resultKnowledge[index],
4510 ValueKnowledge::getKnowledgeFromType(it.value().getType()));
4511 if (!meet)
4512 continue;
4513 resultKnowledge[index] = meet;
4514 }
4515 }
4516
4517 for (const ValueKnowledge &result : resultKnowledge) {
4518 inferredReturnShapes.push_back(result.getShapedTypeComponents());
4519 }
4520
4521 return success();
4522}
4523
LogicalResult WhileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    WhileOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // Collect the yields of the body region; they carry the next iteration's
  // state and therefore determine the result types.
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (auto &block : adaptor.getBodyGraph())
    if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
      yieldOps.push_back(returnOp);

  // TOSA's while must have a tosa.yield as its terminator. If not found this
  // tosa.while is invalid.
  if (yieldOps.empty())
    return failure();

  // Get the initial type information from the operand types.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  // Meet each yield into the accumulated knowledge; a failed meet leaves the
  // previously accumulated (more conservative) knowledge in place.
  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      if (auto meet = ValueKnowledge::meet(
              resultKnowledge[index],
              ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
        resultKnowledge[index] = meet;
      }
    }
  }

  for (const ValueKnowledge &result : resultKnowledge) {
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  }

  return success();
}
4566
4567std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
4568 if (auto vt = llvm::dyn_cast<VectorType>(getType()))
4569 return llvm::to_vector<4>(vt.getShape());
4570 return std::nullopt;
4571}
4572
4574 Block::BlockArgListType blocksArgs,
4575 ValueRange initializers,
4576 StringRef prefix = "") {
4577 assert(blocksArgs.size() == initializers.size() &&
4578 "expected same length of arguments and initializers");
4579 if (initializers.empty())
4580 return;
4581
4582 parser << prefix << '(';
4583 llvm::interleaveComma(
4584 llvm::zip(blocksArgs, initializers), parser,
4585 [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
4586 parser << ")";
4587}
4588
4589// parse and print of IfOp refer to the implementation of SCF dialect.
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
  // Syntax: cond (arg = operand, ...)? : cond-type
  //         ((block-arg-types) -> result-types | (-> result-types)?)
  //         then-region ("else" else-region)? attr-dict?
  // Create the regions for 'then'.
  result.regions.reserve(2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();

  OpAsmParser::UnresolvedOperand cond;

  // Parse the condition operand first.
  if (parser.parseOperand(cond))
    return failure();

  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;

  // Parse the optional block arguments
  OptionalParseResult listResult =
      parser.parseOptionalAssignmentList(regionArgs, operands);
  if (listResult.has_value() && failed(listResult.value()))
    return failure();

  // Parse a colon.
  if (failed(parser.parseColon()))
    return parser.emitError(parser.getCurrentLocation(),
                            "expected type for condition operand");

  // Parse the type of the condition operand
  Type condType;
  if (failed(parser.parseType(condType)))
    return parser.emitError(parser.getCurrentLocation(),
                            "expected type for condition operand");

  // Resolve operand with provided type
  if (failed(parser.resolveOperand(cond, condType, result.operands)))
    return failure();

  // Parse optional block arg types
  if (listResult.has_value()) {
    // When an assignment list was present the type signature takes the
    // functional form `(block-arg-types) -> result-types`.
    FunctionType functionType;

    if (failed(parser.parseType(functionType)))
      return parser.emitError(parser.getCurrentLocation())
             << "expected list of types for block arguments "
             << "followed by arrow type and list of return types";

    result.addTypes(functionType.getResults());

    // Each captured operand needs a corresponding input type.
    if (functionType.getNumInputs() != operands.size()) {
      return parser.emitError(parser.getCurrentLocation())
             << "expected as many input types as operands " << "(expected "
             << operands.size() << " got " << functionType.getNumInputs()
             << ")";
    }

    // Resolve input operands.
    if (failed(parser.resolveOperands(operands, functionType.getInputs(),
                                      parser.getCurrentLocation(),
                                      result.operands)))
      return failure();
  } else {
    // Parse optional results type list.
    if (parser.parseOptionalArrowTypeList(result.types))
      return failure();
  }

  // Parse the 'then' region.
  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();

  // If we find an 'else' keyword then parse the 'else' region.
  if (!parser.parseOptionalKeyword("else")) {
    if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
      return failure();
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
4669
void IfOp::print(OpAsmPrinter &p) {
  p << " " << getCondition();

  // Print the optional `(arg = init, ...)` assignment list binding the
  // 'then' region arguments to the input list.
  printInitializationList(p, getThenGraph().front().getArguments(),
                          getInputList(), " ");
  p << " : ";
  p << getCondition().getType();

  // When inputs are present the type signature is the functional form
  // `(input-types) -> result-types`; otherwise just the arrow list.
  if (!getInputList().empty()) {
    p << " (";
    llvm::interleaveComma(getInputList().getTypes(), p);
    p << ")";
  }
  p.printArrowTypeList(getResultTypes());
  p << " ";

  p.printRegion(getThenGraph());

  // Print the 'else' regions if it exists and has a block.
  auto &elseRegion = getElseGraph();
  if (!elseRegion.empty()) {
    p << " else ";
    p.printRegion(elseRegion);
  }

  p.printOptionalAttrDict((*this)->getAttrs());
}
4697
LogicalResult IfOp::verify() {
  // The 'then' region arguments must match the input list in type and shape.
  if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
                                 "'then_graph' arguments", getInputList(),
                                 "'input_list'")
          .failed())
    return failure();

  // The 'else' region arguments must match the input list as well.
  if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
                                 "'else_graph' arguments", getInputList(),
                                 "'input_list'")
          .failed())
    return failure();

  // The 'then' yield operands must match the declared output list.
  // MLIR will verify the absence of the terminator for us if otherwise.
  if (getThenGraph().front().mightHaveTerminator()) {
    auto thenYield =
        dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
    if (thenYield && errorIfTypeOrShapeMismatch(
                         *this, thenYield.getInputs(), "'then_graph' results",
                         getOutputList(), "'output_list'")
                         .failed())
      return failure();
  }

  // Likewise for the 'else' yield operands.
  // MLIR will verify the absence of the terminator for us if otherwise.
  if (getElseGraph().front().mightHaveTerminator()) {
    auto elseYield =
        dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
    if (elseYield && errorIfTypeOrShapeMismatch(
                         *this, elseYield.getInputs(), "'else_graph' results",
                         getOutputList(), "'output_list'")
                         .failed())
      return failure();
  }

  // The condition must be a single-element tensor.
  auto condType = getCondition().getType();
  if (errorIfShapeNotSizeOne(*this, condType).failed())
    return emitOpError() << "'condition' must be a size 1 tensor, got "
                         << condType;

  return success();
}
4740
LogicalResult WhileOp::verify() {
  // Inputs and outputs carry the same loop state and must agree.
  if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
                                 getOutputList(), "'output_list'")
          .failed())
    return failure();

  // Both regions receive the loop state as their block arguments.
  if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
                                 "'cond_graph' arguments", getInputList(),
                                 "'input_list'")
          .failed())
    return failure();

  if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
                                 "'body_graph' arguments", getInputList(),
                                 "'input_list'")
          .failed())
    return failure();

  // The body's yield produces the next iteration's state; it must match the
  // input list.
  if (getBodyGraph().front().mightHaveTerminator()) {
    auto bodyYield =
        dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
    if (bodyYield && errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
                                                "'body_graph' results",
                                                getInputList(), "'input_list'")
                         .failed())
      return failure();
  }

  // Condition block output must be a single element tensor with a single bool
  // value.
  if (!getCondGraph().front().mightHaveTerminator())
    return success();

  auto condYield =
      dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
  if (!condYield)
    return success();

  if (condYield.getInputs().size() != 1)
    return emitOpError() << "require 'cond_graph' only have one result";

  auto condOutType = condYield.getInputs()[0].getType();
  if (errorIfShapeNotSizeOne(*this, condOutType).failed())
    return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
                         << condOutType;

  if (!getElementTypeOrSelf(condOutType).isInteger(1))
    return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
                         << condOutType;

  return success();
}
4793
4794LogicalResult ReverseOp::verify() {
4795 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
4796 /* outType = */ getOutput().getType())
4797 .failed())
4798 return failure();
4799 TensorType inputType = getInput1().getType();
4800 TensorType outputType = getOutput().getType();
4801 int32_t reverseAxis = getAxis();
4802
4803 if (reverseAxis < 0)
4804 return emitOpError("expected non-negative reverse axis");
4805 if (inputType.hasRank()) {
4806 int64_t inputRank = inputType.getRank();
4807 // We allow for a special case where the input/output shape has rank 0 and
4808 // axis is also 0.
4809 if (reverseAxis >= inputRank && (reverseAxis != 0 || inputRank != 0))
4810 return emitOpError("expect input tensor rank (")
4811 << inputRank << ") to be larger than reverse axis (" << reverseAxis
4812 << ")";
4813 }
4814 if (outputType.hasRank()) {
4815 int64_t outputRank = outputType.getRank();
4816 if (inputType.hasRank() && outputRank != inputType.getRank())
4817 return emitOpError(
4818 "expect output tensor rank to be equal to input tensor rank");
4819 if (reverseAxis >= outputRank && (reverseAxis != 0 || outputRank != 0))
4820 return emitOpError("expect output tensor rank (")
4821 << outputRank << ") to be larger than reverse axis ("
4822 << reverseAxis << ")";
4823 }
4824 return success();
4825}
4826
4827LogicalResult tosa::SelectOp::verify() {
4828 // verify input2 and input3 have same element type as output
4829 if (verifySameElementTypes(*this, /* inType = */ getOnTrue().getType(),
4830 /* outType = */ getOutput().getType())
4831 .failed() ||
4832 verifySameElementTypes(*this, /* inType = */ getOnFalse().getType(),
4833 /* outType = */ getOutput().getType())
4834 .failed()) {
4835 return failure();
4836 }
4837 // verify input1 has element type of bool
4838 auto predicateType = llvm::dyn_cast<ShapedType>(getPred().getType());
4839 if (!predicateType) {
4840 return emitOpError("expect shaped tensor for input1, got ")
4841 << getInput1().getType();
4842 }
4843 auto predicateElementType = predicateType.getElementType();
4844 if (!predicateElementType.isInteger(1)) {
4845 return emitOpError("expect element type of bool for input1, got ")
4846 << predicateElementType;
4847 }
4848
4849 return success();
4850}
4851
4852LogicalResult tosa::VariableReadOp::verify() {
4853 if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
4854 .failed())
4855 return failure();
4856
4857 return success();
4858}
4859
4860LogicalResult tosa::VariableWriteOp::verify() {
4861 if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'")
4862 .failed())
4863 return failure();
4864
4865 return success();
4866}
4867
4868// parse and print of WhileOp refer to the implementation of SCF dialect.
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
  // Syntax: (arg = operand, ...)? : functional-type
  //         cond-region "do" body-region attr-dict-with-keyword?
  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
  Region *cond = result.addRegion();
  Region *body = result.addRegion();

  // Parse the optional `(arg = operand, ...)` assignment list.
  OptionalParseResult listResult =
      parser.parseOptionalAssignmentList(regionArgs, operands);
  if (listResult.has_value() && failed(listResult.value()))
    return failure();

  FunctionType functionType;
  SMLoc typeLoc = parser.getCurrentLocation();
  if (failed(parser.parseColonType(functionType)))
    return failure();

  result.addTypes(functionType.getResults());

  // Each iteration-carried operand needs a corresponding input type.
  if (functionType.getNumInputs() != operands.size()) {
    return parser.emitError(typeLoc)
           << "expected as many input types as operands " << "(expected "
           << operands.size() << " got " << functionType.getNumInputs() << ")";
  }

  // Resolve input operands.
  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
                                    parser.getCurrentLocation(),
                                    result.operands)))
    return failure();

  // Propagate the types into the region arguments.
  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
    regionArgs[i].type = functionType.getInput(i);

  return failure(parser.parseRegion(*cond, regionArgs) ||
                 parser.parseKeyword("do") || parser.parseRegion(*body) ||
                 parser.parseOptionalAttrDictWithKeyword(result.attributes));
}
4907
4908void WhileOp::print(OpAsmPrinter &parser) {
4909 printInitializationList(parser, getCondGraph().front().getArguments(),
4910 getInputList(), " ");
4911 parser << " : ";
4912 parser.printFunctionalType(getInputList().getTypes(),
4913 getResults().getTypes());
4914 parser << ' ';
4915 parser.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
4916 parser << " do ";
4917 parser.printRegion(getBodyGraph());
4918 parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
4919}
4920
4921// Create a rank-1 const tensor for zero point of the source tensor.
4922std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
4923 Location loc,
4924 Type srcElemType,
4925 int64_t zp) {
4926 srcElemType = getStorageElementTypeOrSelf(srcElemType);
4927 auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
4928 if (llvm::isa<FloatType>(srcElemType)) {
4929 auto zpAttr = DenseElementsAttr::get(
4930 zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
4931 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4932 }
4933 if (llvm::isa<IntegerType>(srcElemType)) {
4934 auto zpAttr =
4935 DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
4936 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4937 }
4938 llvm::errs() << "zero point is not allowed for unsupported data types\n";
4939 return std::nullopt;
4940}
4941
4942//===----------------------------------------------------------------------===//
4943// TOSA Shape and Shape Operators Helper functions.
4944//===----------------------------------------------------------------------===//
4945
4947 return mlir::isa<tosa::shapeType>(t);
4948}
4949
4950LogicalResult
4951mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
4952 int rank) {
4953 if (rank < 0)
4954 return emitError() << "invalid rank (must be >= 0): " << rank;
4955 return success();
4956}
4957
4959 for (auto v : op->getOperands()) {
4960 if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
4961 Operation *definingOp = v.getDefiningOp();
4962 if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
4963 return op->emitOpError("shape operand is not compile time resolvable");
4964 }
4965 }
4966 }
4967 return success();
4968}
4969
4970LogicalResult
4972 if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)))
4973 return failure();
4974
4975 // delegate function that returns rank of shape type
4976 auto getRank = [](const Type type) {
4977 return mlir::cast<mlir::tosa::shapeType>(type).getRank();
4978 };
4979 auto operandTypes = op->getOperandTypes();
4980 auto resultTypes = op->getResultTypes();
4981
4982 auto rank = getRank(*op->getOperandTypes().begin());
4983 for (auto type : operandTypes) {
4984 if (getRank(type) != rank) {
4985 return op->emitOpError("operands don't have matching ranks");
4986 }
4987 }
4988 for (auto type : resultTypes) {
4989 if (getRank(type) != rank) {
4990 return op->emitOpError("result shape has different rank than operands");
4991 }
4992 }
4993 return success();
4994}
4995
4996//===----------------------------------------------------------------------===//
4997// TOSA Shape Operators verify functions.
4998//===----------------------------------------------------------------------===//
4999
5000LogicalResult tosa::ConstShapeOp::verify() {
5001 // check one dimensional rank
5002 auto valuesRank = getValues().getType().getRank();
5003 if (valuesRank != 1)
5004 return emitOpError("expect elements in attribute values with rank 1");
5005 // check that number of elements in values attr equal to rank of result shape
5006 auto count = getValues().getNumElements();
5007 auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
5008 if (count != rank && (count != 1 || rank != 0)) {
5009 return emitOpError("expect number of elements in attribute values (")
5010 << count << ") to be equal to the rank (" << rank
5011 << ") for the result shape type";
5012 }
5013 return success();
5014}
5015
5016LogicalResult tosa::DimOp::verify() {
5017 const tosa::shapeType outShapeType =
5018 cast<tosa::shapeType>(getResult().getType());
5019 if (outShapeType.getRank() != 1)
5020 return emitOpError("expect output shape type to contain one element, got ")
5021 << outShapeType;
5022
5023 const ShapeAdaptor inputType(getInput1().getType());
5024 if (inputType.hasRank()) {
5025 const int64_t inputRank = inputType.getRank();
5026 const int64_t axis = getAxisAttr().getInt();
5027 if (axis < 0 || axis >= inputRank)
5028 return emitOpError("expect axis to be in the range [0, ")
5029 << inputRank << "), got " << axis;
5030 }
5031 return success();
5032}
5033
5034LogicalResult tosa::ConcatShapeOp::verify() {
5035 const tosa::shapeType outShapeType =
5036 cast<tosa::shapeType>(getResult().getType());
5037 const int64_t outputRank = outShapeType.getRank();
5038 const Operation::operand_range inputList = getInput();
5039
5040 if (inputList.size() == 0)
5041 return emitOpError("requires at least one input shape");
5042
5043 if (llvm::any_of(inputList, [](Value v) {
5044 return cast<tosa::shapeType>(v.getType()).getRank() == 0;
5045 }))
5046 return emitOpError("requires all inputs shapes have a rank greater than 0");
5047
5048 const int64_t inputsRank =
5049 llvm::accumulate(inputList, 0, [](int64_t acc, const Value &input) {
5050 const tosa::shapeType inShapeType =
5051 cast<tosa::shapeType>(input.getType());
5052 return acc + inShapeType.getRank();
5053 });
5054 if (outputRank != inputsRank)
5055 return emitOpError("requires output shape rank to be equal to the sum of "
5056 "the input shape ranks (")
5057 << inputsRank << "), got " << outputRank;
5058
5059 return success();
5060}
5061
5062LogicalResult tosa::SliceShapeOp::verify() {
5063 std::optional<int32_t> start;
5064 DenseIntElementsAttr startAttr;
5065 if (matchPattern(getStart(), m_Constant(&startAttr)))
5066 start = startAttr.getValues<int32_t>()[0];
5067 if (start && start.value() < 0)
5068 return emitOpError("expected non-negative start index, got ")
5069 << start.value();
5070
5071 std::optional<int32_t> size;
5072 DenseIntElementsAttr sizeAttr;
5073 if (matchPattern(getSize(), m_Constant(&sizeAttr)))
5074 size = sizeAttr.getValues<int32_t>()[0];
5075 if (size && size.value() <= 0)
5076 return emitOpError("expected positive size, got ") << size.value();
5077
5078 if (!size)
5079 return success();
5080
5081 const tosa::shapeType outShapeType =
5082 cast<tosa::shapeType>(getResult().getType());
5083 const int64_t outputRank = outShapeType.getRank();
5084 if (outputRank != size)
5085 return emitOpError(
5086 "expected output type size to be equal to size attribute, got ")
5087 << outputRank << " vs " << size.value();
5088
5089 if (!start)
5090 return success();
5091
5092 const tosa::shapeType inShapeType =
5093 cast<tosa::shapeType>(getInput().getType());
5094 const int64_t inputRank = inShapeType.getRank();
5095 const int64_t sliceSize = start.value() + size.value();
5096 if (sliceSize > inputRank)
5097 return emitOpError("expected start + size to be less than or equal to "
5098 "input shape rank (")
5099 << inputRank << "), got " << sliceSize;
5100
5101 return success();
5102}
5103
5104//===----------------------------------------------------------------------===//
5105// TOSA Attribute Definitions.
5106//===----------------------------------------------------------------------===//
5107
5108#define GET_ATTRDEF_CLASSES
5109#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
5110
5111//===----------------------------------------------------------------------===//
5112// TOSA Type Definitions.
5113//===----------------------------------------------------------------------===//
5114#define GET_TYPEDEF_CLASSES
5115#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
5116
5117//===----------------------------------------------------------------------===//
5118// TOSA Operator Definitions.
5119//===----------------------------------------------------------------------===//
5120
5121#define GET_OP_CLASSES
5122#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition SCF.cpp:496
true
Given two iterators into the same block, return "true" if a is before `b.
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value a, Value b)
The tosa.matmul op is also intended to be generated where a fully_connected op must be constructed wh...
Definition TosaOps.cpp:1326
static LogicalResult verifySameElementTypes(T op, Type aType, Type bType, StringRef aName="input", StringRef bName="output")
Definition TosaOps.cpp:1015
LogicalResult inferConvReturnTypeComponents(AdaptorT adaptor, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition TosaOps.cpp:3705
static SmallVector< int64_t > convertToMlirShape(ArrayRef< int64_t > shape)
Definition TosaOps.cpp:138
static LogicalResult ReduceInferReturnTypes(ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition TosaOps.cpp:3253
static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val, Value valZp, StringRef name)
Definition TosaOps.cpp:585
static LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type)
Definition TosaOps.cpp:979
#define REDUCE_SHAPE_INFER(OP)
Definition TosaOps.cpp:3278
static LogicalResult verifyConvOp(T op)
Definition TosaOps.cpp:680
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name)
Definition TosaOps.cpp:988
static LogicalResult poolingInferReturnTypes(ShapeAdaptor inputShape, ArrayRef< int64_t > kernel, ArrayRef< int64_t > stride, ArrayRef< int64_t > pad, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition TosaOps.cpp:3468
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings)
This builder is called on TOSA pad operator that needs to create its own OptionalAttr quantization_at...
Definition TosaOps.cpp:1440
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
Definition TosaOps.cpp:568
static void buildVariableOp(OpBuilder &builder, OperationState &result, StringRef name, Type variableType, Attribute initialValue)
Definition TosaOps.cpp:1454
LogicalResult verifyConvOutputSize(Operation *op, const int64_t inputSize, const int64_t kernelSize, const int64_t outputSize, const int64_t padBefore, const int64_t padAfter, const int64_t stride, const int64_t dilation, const llvm::StringRef dimName, const llvm::StringRef dimAxis, const llvm::StringRef padBeforeName, const llvm::StringRef padAfterName)
Definition TosaOps.cpp:643
static LogicalResult verifyReduceOp(T op)
Definition TosaOps.cpp:3303
#define NARY_SHAPE_INFER(OP)
Definition TosaOps.cpp:3371
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)
Definition TosaOps.cpp:2750
static void buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, DenseI64ArrayAttr stride, TypeAttr accType)
Handles tosa.transpose_conv2d which has outpad and output shape attributes.
Definition TosaOps.cpp:1304
static LogicalResult verifyConvOpErrorIf(T op)
Definition TosaOps.cpp:834
static FailureOr< int64_t > getZeroPoint(Value val, bool signExtend)
Definition TosaOps.cpp:2680
LogicalResult tryUpdateDimOrFailure(Operation *op, int64_t &currDim, const int64_t newDim, const StringRef operandName, const StringRef dimName)
Definition TosaOps.cpp:628
static LogicalResult verifyConvOpModes(T op)
Definition TosaOps.cpp:785
static LogicalResult NAryInferReturnTypes(const ValueShapeRange &operands, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition TosaOps.cpp:3359
static Type getStorageElementTypeOrSelf(Type type)
Definition TosaOps.cpp:574
#define COMPATIBLE_RETURN_TYPES(OP)
Definition TosaOps.cpp:3269
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, SmallVector< int64_t > &outShape)
Definition TosaOps.cpp:1498
static void buildNegateOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input)
This builder is called on single-parameter negate operator to construct input and output zero points ...
Definition TosaOps.cpp:1400
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr pad, DenseI64ArrayAttr stride, DenseI64ArrayAttr dilation, TypeAttr accType)
This builder is called on all convolution operators except TransposeConv, which has specialized outpu...
Definition TosaOps.cpp:1280
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad, TypeAttr accType)
Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr but avg_pool operator has...
Definition TosaOps.cpp:1355
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1, StringRef name1, Type type2, StringRef name2)
Definition TosaOps.cpp:937
static FailureOr< int64_t > resolveBroadcastDim(const int64_t dim1, const int64_t dim2)
Definition TosaOps.cpp:1484
static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp, const std::string &operand)
Definition TosaOps.cpp:2708
static LogicalResult verifyPoolingOp(T op)
Definition TosaOps.cpp:1081
static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize, const llvm::StringRef dimName)
Definition TosaOps.cpp:1584
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
static void updateIfDynamic(int64_t &current, int64_t candidate)
Definition TosaOps.cpp:3507
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
Definition TosaOps.cpp:3600
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
Definition TosaOps.cpp:3629
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
Definition TosaOps.cpp:3574
ConvInferShapeAdaptor(Conv2DBlockScaledOp::Adaptor adaptor)
Definition TosaOps.cpp:3571
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
Definition TosaOps.cpp:3520
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
Definition TosaOps.cpp:3535
ConvInferShapeAdaptor(Conv2DOp::Adaptor adaptor)
Definition TosaOps.cpp:3517
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
Definition TosaOps.cpp:3553
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
Definition TosaOps.cpp:3670
ConvInferShapeAdaptor(Conv3DOp::Adaptor adaptor)
Definition TosaOps.cpp:3650
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
Definition TosaOps.cpp:3653
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
Definition TosaOps.cpp:3690
virtual ParseResult parseOptionalRBrace()=0
Parse a } token if present.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void printAttribute(Attribute attr)
void printArrowTypeList(TypeRange &&types)
Attributes are known-constant values of operations.
Definition Attributes.h:25
MutableArrayRef< BlockArgument > BlockArgListType
Definition Block.h:95
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:258
IntegerType getI32Type()
Definition Builders.cpp:67
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:197
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0
Emit an error to the reader.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition Builders.h:209
This class indicates that op operates on tosa shape types.
Definition TosaOps.h:129
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:778
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:541
OperandRange operand_range
Definition Operation.h:400
operand_type_range getOperandTypes()
Definition Operation.h:426
result_type_range getResultTypes()
Definition Operation.h:457
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:407
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:357
bool empty()
Definition Region.h:60
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:90
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getTypes() const
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands)
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op)
Definition TosaOps.cpp:4971
LogicalResult verifyTosaResolvableShapeOperands(Operation *op)
Definition TosaOps.cpp:4958
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Definition Traits.cpp:59
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)
Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilde...
RankedTensorType getVariableType(VariableOp variableOp)
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)
construct ConvOp output type with correct bitwidth based on input/weight width.
ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr, Attribute &initialValueAttr)
Definition TosaOps.cpp:227
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, Value input)
Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: inputZp: input zeropoint.
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dime...
Definition TosaOps.h:153
std::pair< Value, Value > createZPsAsConst(OpBuilder &builder, Value input, Value weight)
void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr, TypeAttr typeAttr, Attribute initialValueAttr)
Definition TosaOps.cpp:252
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b)
Builds MatMulOpQuantizationAttr, called from MatMulOpQuantInfoBuilder: aZp: input a zeropoint bZp: in...
unsigned getBitWidth(Type type)
Definition TosaOps.cpp:620
std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)
Definition TosaOps.cpp:4922
bool isa_tosa_shape_type(mlir::Type t)
Definition TosaOps.cpp:4946
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)
Builds UnaryOpQuantizationAttr UnaryOpQuantInfoBuilder: inputZp: input zeropoint outputZp: output zer...
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src, int32_t val=0)
Definition TosaOps.cpp:605
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult verifyCompatibleDims(ArrayRef< int64_t > dims)
Dimensions are compatible if all non-dynamic dims are equal.
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition ShapeUtils.h:136
static ValueKnowledge getKnowledgeFromType(Type type)
Definition ShapeUtils.h:45